summaryrefslogtreecommitdiff
path: root/silx/gui
diff options
context:
space:
mode:
Diffstat (limited to 'silx/gui')
-rw-r--r--silx/gui/__init__.py29
-rw-r--r--silx/gui/_glutils/Context.py63
-rw-r--r--silx/gui/_glutils/FramebufferTexture.py164
-rw-r--r--silx/gui/_glutils/Program.py202
-rw-r--r--silx/gui/_glutils/Texture.py308
-rw-r--r--silx/gui/_glutils/VertexBuffer.py266
-rw-r--r--silx/gui/_glutils/__init__.py41
-rw-r--r--silx/gui/_glutils/font.py152
-rw-r--r--silx/gui/_glutils/gl.py165
-rw-r--r--silx/gui/_glutils/utils.py70
-rw-r--r--silx/gui/_utils.py102
-rw-r--r--silx/gui/console.py214
-rw-r--r--silx/gui/data/ArrayTableModel.py610
-rw-r--r--silx/gui/data/ArrayTableWidget.py490
-rw-r--r--silx/gui/data/DataViewer.py464
-rw-r--r--silx/gui/data/DataViewerFrame.py186
-rw-r--r--silx/gui/data/DataViewerSelector.py153
-rw-r--r--silx/gui/data/DataViews.py988
-rw-r--r--silx/gui/data/Hdf5TableView.py414
-rw-r--r--silx/gui/data/NXdataWidgets.py523
-rw-r--r--silx/gui/data/NumpyAxesSelector.py468
-rw-r--r--silx/gui/data/RecordTableView.py405
-rw-r--r--silx/gui/data/TextFormatter.py222
-rw-r--r--silx/gui/data/__init__.py35
-rw-r--r--silx/gui/data/setup.py41
-rw-r--r--silx/gui/data/test/__init__.py45
-rw-r--r--silx/gui/data/test/test_arraywidget.py320
-rw-r--r--silx/gui/data/test/test_dataviewer.py281
-rw-r--r--silx/gui/data/test/test_numpyaxesselector.py152
-rw-r--r--silx/gui/data/test/test_textformatter.py94
-rw-r--r--silx/gui/fit/BackgroundWidget.py530
-rw-r--r--silx/gui/fit/FitConfig.py540
-rw-r--r--silx/gui/fit/FitWidget.py727
-rw-r--r--silx/gui/fit/FitWidgets.py559
-rw-r--r--silx/gui/fit/Parameters.py882
-rw-r--r--silx/gui/fit/__init__.py28
-rw-r--r--silx/gui/fit/setup.py43
-rw-r--r--silx/gui/fit/test/__init__.py43
-rw-r--r--silx/gui/fit/test/testBackgroundWidget.py83
-rw-r--r--silx/gui/fit/test/testFitConfig.py95
-rw-r--r--silx/gui/fit/test/testFitWidget.py135
-rw-r--r--silx/gui/hdf5/Hdf5HeaderView.py192
-rw-r--r--silx/gui/hdf5/Hdf5Item.py421
-rw-r--r--silx/gui/hdf5/Hdf5LoadingItem.py68
-rw-r--r--silx/gui/hdf5/Hdf5Node.py210
-rw-r--r--silx/gui/hdf5/Hdf5TreeModel.py581
-rw-r--r--silx/gui/hdf5/Hdf5TreeView.py204
-rw-r--r--silx/gui/hdf5/NexusSortFilterProxyModel.py152
-rw-r--r--silx/gui/hdf5/__init__.py44
-rw-r--r--silx/gui/hdf5/_utils.py247
-rw-r--r--silx/gui/hdf5/setup.py41
-rw-r--r--silx/gui/hdf5/test/__init__.py39
-rw-r--r--silx/gui/hdf5/test/_mock.py130
-rw-r--r--silx/gui/hdf5/test/test_hdf5.py480
-rw-r--r--silx/gui/icons.py360
-rw-r--r--silx/gui/plot/AlphaSlider.py300
-rw-r--r--silx/gui/plot/ColorBar.py790
-rw-r--r--silx/gui/plot/ColormapDialog.py506
-rw-r--r--silx/gui/plot/Colors.py359
-rw-r--r--silx/gui/plot/CurvesROIWidget.py975
-rw-r--r--silx/gui/plot/ImageView.py860
-rw-r--r--silx/gui/plot/Interaction.py300
-rw-r--r--silx/gui/plot/LegendSelector.py1087
-rw-r--r--silx/gui/plot/MPLColormap.py1062
-rw-r--r--silx/gui/plot/MaskToolsWidget.py615
-rw-r--r--silx/gui/plot/Plot.py2925
-rw-r--r--silx/gui/plot/PlotActions.py1386
-rw-r--r--silx/gui/plot/PlotEvents.py166
-rw-r--r--silx/gui/plot/PlotInteraction.py1493
-rw-r--r--silx/gui/plot/PlotToolButtons.py280
-rw-r--r--silx/gui/plot/PlotTools.py313
-rw-r--r--silx/gui/plot/PlotWidget.py267
-rw-r--r--silx/gui/plot/PlotWindow.py766
-rw-r--r--silx/gui/plot/Profile.py741
-rw-r--r--silx/gui/plot/ProfileMainWindow.py99
-rw-r--r--silx/gui/plot/ScatterMaskToolsWidget.py529
-rw-r--r--silx/gui/plot/StackView.py1033
-rw-r--r--silx/gui/plot/_BaseMaskToolsWidget.py1138
-rw-r--r--silx/gui/plot/__init__.py71
-rw-r--r--silx/gui/plot/_utils/__init__.py104
-rw-r--r--silx/gui/plot/_utils/panzoom.py156
-rw-r--r--silx/gui/plot/_utils/setup.py42
-rw-r--r--silx/gui/plot/_utils/test/__init__.py41
-rw-r--r--silx/gui/plot/_utils/test/test_ticklayout.py78
-rw-r--r--silx/gui/plot/_utils/ticklayout.py224
-rw-r--r--silx/gui/plot/backends/BackendBase.py474
-rw-r--r--silx/gui/plot/backends/BackendMatplotlib.py821
-rw-r--r--silx/gui/plot/backends/BackendOpenGL.py1631
-rw-r--r--silx/gui/plot/backends/ModestImage.py174
-rw-r--r--silx/gui/plot/backends/__init__.py29
-rw-r--r--silx/gui/plot/backends/_matplotlib.py64
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotCurve.py1317
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotFrame.py1039
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotImage.py707
-rw-r--r--silx/gui/plot/backends/glutils/GLSupport.py192
-rw-r--r--silx/gui/plot/backends/glutils/GLText.py222
-rw-r--r--silx/gui/plot/backends/glutils/GLTexture.py239
-rw-r--r--silx/gui/plot/backends/glutils/PlotImageFile.py149
-rw-r--r--silx/gui/plot/backends/glutils/__init__.py44
-rw-r--r--silx/gui/plot/items/__init__.py43
-rw-r--r--silx/gui/plot/items/core.py839
-rw-r--r--silx/gui/plot/items/curve.py192
-rw-r--r--silx/gui/plot/items/histogram.py288
-rw-r--r--silx/gui/plot/items/image.py385
-rw-r--r--silx/gui/plot/items/marker.py241
-rw-r--r--silx/gui/plot/items/scatter.py169
-rw-r--r--silx/gui/plot/items/shape.py121
-rw-r--r--silx/gui/plot/setup.py47
-rw-r--r--silx/gui/plot/test/__init__.py71
-rw-r--r--silx/gui/plot/test/testAlphaSlider.py221
-rw-r--r--silx/gui/plot/test/testColorBar.py240
-rw-r--r--silx/gui/plot/test/testColormapDialog.py68
-rw-r--r--silx/gui/plot/test/testColors.py94
-rw-r--r--silx/gui/plot/test/testCurvesROIWidget.py153
-rw-r--r--silx/gui/plot/test/testInteraction.py89
-rw-r--r--silx/gui/plot/test/testLegendSelector.py143
-rw-r--r--silx/gui/plot/test/testMaskToolsWidget.py295
-rw-r--r--silx/gui/plot/test/testPlot.py633
-rw-r--r--silx/gui/plot/test/testPlotInteraction.py167
-rw-r--r--silx/gui/plot/test/testPlotTools.py203
-rw-r--r--silx/gui/plot/test/testPlotWidget.py967
-rw-r--r--silx/gui/plot/test/testPlotWindow.py138
-rw-r--r--silx/gui/plot/test/testProfile.py183
-rw-r--r--silx/gui/plot/test/testScatterMaskToolsWidget.py313
-rw-r--r--silx/gui/plot/test/testStackView.py209
-rw-r--r--silx/gui/plot3d/Plot3DActions.py362
-rw-r--r--silx/gui/plot3d/Plot3DToolBar.py119
-rw-r--r--silx/gui/plot3d/Plot3DWidget.py341
-rw-r--r--silx/gui/plot3d/Plot3DWindow.py94
-rw-r--r--silx/gui/plot3d/SFViewParamTree.py1467
-rw-r--r--silx/gui/plot3d/ScalarFieldView.py1385
-rw-r--r--silx/gui/plot3d/ViewpointToolBar.py114
-rw-r--r--silx/gui/plot3d/__init__.py45
-rw-r--r--silx/gui/plot3d/scene/__init__.py34
-rw-r--r--silx/gui/plot3d/scene/axes.py224
-rw-r--r--silx/gui/plot3d/scene/camera.py350
-rw-r--r--silx/gui/plot3d/scene/core.py334
-rw-r--r--silx/gui/plot3d/scene/cutplane.py374
-rw-r--r--silx/gui/plot3d/scene/event.py225
-rw-r--r--silx/gui/plot3d/scene/function.py471
-rw-r--r--silx/gui/plot3d/scene/interaction.py652
-rw-r--r--silx/gui/plot3d/scene/primitives.py1764
-rw-r--r--silx/gui/plot3d/scene/setup.py41
-rw-r--r--silx/gui/plot3d/scene/test/__init__.py43
-rw-r--r--silx/gui/plot3d/scene/test/test_transform.py91
-rw-r--r--silx/gui/plot3d/scene/test/test_utils.py275
-rw-r--r--silx/gui/plot3d/scene/text.py534
-rw-r--r--silx/gui/plot3d/scene/transform.py968
-rw-r--r--silx/gui/plot3d/scene/utils.py516
-rw-r--r--silx/gui/plot3d/scene/viewport.py492
-rw-r--r--silx/gui/plot3d/scene/window.py420
-rw-r--r--silx/gui/plot3d/setup.py44
-rw-r--r--silx/gui/plot3d/test/__init__.py62
-rw-r--r--silx/gui/plot3d/utils/__init__.py28
-rw-r--r--silx/gui/plot3d/utils/mng.py121
-rw-r--r--silx/gui/qt/__init__.py61
-rw-r--r--silx/gui/qt/_macosx.py68
-rw-r--r--silx/gui/qt/_pyside_dynamic.py158
-rw-r--r--silx/gui/qt/_pyside_missing.py274
-rw-r--r--silx/gui/qt/_qt.py229
-rw-r--r--silx/gui/qt/_utils.py44
-rw-r--r--silx/gui/setup.py51
-rw-r--r--silx/gui/test/__init__.py108
-rw-r--r--silx/gui/test/test_console.py91
-rw-r--r--silx/gui/test/test_icons.py116
-rw-r--r--silx/gui/test/test_qt.py144
-rw-r--r--silx/gui/test/test_utils.py77
-rw-r--r--silx/gui/test/utils.py428
-rw-r--r--silx/gui/widgets/FrameBrowser.py307
-rw-r--r--silx/gui/widgets/HierarchicalTableView.py172
-rw-r--r--silx/gui/widgets/MedianFilterDialog.py74
-rw-r--r--silx/gui/widgets/PeriodicTable.py825
-rw-r--r--silx/gui/widgets/TableWidget.py488
-rw-r--r--silx/gui/widgets/ThreadPoolPushButton.py233
-rw-r--r--silx/gui/widgets/WaitingPushButton.py243
-rw-r--r--silx/gui/widgets/__init__.py27
-rw-r--r--silx/gui/widgets/setup.py41
-rw-r--r--silx/gui/widgets/test/__init__.py45
-rw-r--r--silx/gui/widgets/test/test_hierarchicaltableview.py117
-rw-r--r--silx/gui/widgets/test/test_periodictable.py163
-rw-r--r--silx/gui/widgets/test/test_tablewidget.py61
-rw-r--r--silx/gui/widgets/test/test_threadpoolpushbutton.py129
182 files changed, 63585 insertions, 0 deletions
diff --git a/silx/gui/__init__.py b/silx/gui/__init__.py
new file mode 100644
index 0000000..6baf238
--- /dev/null
+++ b/silx/gui/__init__.py
@@ -0,0 +1,29 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Set of Qt widgets"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "23/05/2016"
diff --git a/silx/gui/_glutils/Context.py b/silx/gui/_glutils/Context.py
new file mode 100644
index 0000000..7600992
--- /dev/null
+++ b/silx/gui/_glutils/Context.py
@@ -0,0 +1,63 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Abstraction of OpenGL context.
+
+It defines a way to get current OpenGL context to support multiple
+OpenGL contexts.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+# context #####################################################################
+
+
+def _defaultGLContextGetter():
+ return None
+
+_glContextGetter = _defaultGLContextGetter
+
+
+def getGLContext():
+ """Returns platform dependent object of current OpenGL context.
+
+ This is useful to associate OpenGL resources with the context they are
+ created in.
+
+ :return: Platform specific OpenGL context
+ :rtype: None by default or a platform dependent object"""
+ return _glContextGetter()
+
+
+def setGLContextGetter(getter=_defaultGLContextGetter):
+ """Set a platform dependent function to retrieve the current OpenGL context
+
+ :param getter: Platform dependent GL context getter
+ :type getter: Function with no args returning the current OpenGL context
+ """
+ global _glContextGetter
+ _glContextGetter = getter
diff --git a/silx/gui/_glutils/FramebufferTexture.py b/silx/gui/_glutils/FramebufferTexture.py
new file mode 100644
index 0000000..b01eb41
--- /dev/null
+++ b/silx/gui/_glutils/FramebufferTexture.py
@@ -0,0 +1,164 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Association of a texture and a framebuffer object for off-screen rendering.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import logging
+
+from . import gl
+from .Texture import Texture
+
+
+_logger = logging.getLogger(__name__)
+
+
+class FramebufferTexture(object):
+ """Framebuffer with a texture.
+
+ Aimed at off-screen rendering to texture.
+
+ :param internalFormat: OpenGL texture internal format
+ :param shape: Shape (height, width) of the framebuffer and texture
+ :type shape: 2-tuple of int
+ :param stencilFormat: Stencil renderbuffer format
+ :param depthFormat: Depth renderbuffer format
+ :param kwargs: Extra arguments for :class:`Texture` constructor
+ """
+
+ _PACKED_FORMAT = gl.GL_DEPTH24_STENCIL8, gl.GL_DEPTH_STENCIL
+
+ def __init__(self,
+ internalFormat,
+ shape,
+ stencilFormat=gl.GL_DEPTH24_STENCIL8,
+ depthFormat=gl.GL_DEPTH24_STENCIL8,
+ **kwargs):
+
+ self._texture = Texture(internalFormat, shape=shape, **kwargs)
+
+ self._previousFramebuffer = 0 # Used by with statement
+
+ self._name = gl.glGenFramebuffers(1)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self._name)
+
+ # Attachments
+ gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER,
+ gl.GL_COLOR_ATTACHMENT0,
+ gl.GL_TEXTURE_2D,
+ self._texture.name,
+ 0)
+
+ height, width = self._texture.shape
+
+ if stencilFormat is not None:
+ self._stencilId = gl.glGenRenderbuffers(1)
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._stencilId)
+ gl.glRenderbufferStorage(gl.GL_RENDERBUFFER,
+ stencilFormat,
+ width, height)
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER,
+ gl.GL_STENCIL_ATTACHMENT,
+ gl.GL_RENDERBUFFER,
+ self._stencilId)
+ else:
+ self._stencilId = None
+
+ if depthFormat is not None:
+ if self._stencilId and depthFormat in self._PACKED_FORMAT:
+ self._depthId = self._stencilId
+ else:
+ self._depthId = gl.glGenRenderbuffers(1)
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._depthId)
+ gl.glRenderbufferStorage(gl.GL_RENDERBUFFER,
+ depthFormat,
+ width, height)
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER,
+ gl.GL_DEPTH_ATTACHMENT,
+ gl.GL_RENDERBUFFER,
+ self._depthId)
+ else:
+ self._depthId = None
+
+ assert gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) == \
+ gl.GL_FRAMEBUFFER_COMPLETE
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
+
+ @property
+ def shape(self):
+ """Shape of the framebuffer (height, width)"""
+ return self._texture.shape
+
+ @property
+ def texture(self):
+ """The texture this framebuffer is rendering to.
+
+ The life-cycle of the texture is managed by this object"""
+ return self._texture
+
+ @property
+ def name(self):
+ """OpenGL name of the framebuffer"""
+ if self._name is not None:
+ return self._name
+ else:
+ raise RuntimeError("No OpenGL framebuffer resource, \
+ discard has already been called")
+
+ def bind(self):
+ """Bind this framebuffer for rendering"""
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.name)
+
+ # with statement
+
+ def __enter__(self):
+ self._previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
+ self.bind()
+
+ def __exit__(self, exctype, excvalue, traceback):
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self._previousFramebuffer)
+
+ def discard(self):
+ """Delete associated OpenGL resources including texture"""
+ if self._name is not None:
+ gl.glDeleteFramebuffers(self._name)
+ self._name = None
+
+ if self._stencilId is not None:
+ gl.glDeleteRenderbuffers(self._stencilId)
+ if self._stencilId == self._depthId:
+ self._depthId = None
+ self._stencilId = None
+ if self._depthId is not None:
+ gl.glDeleteRenderbuffers(self._depthId)
+ self._depthId = None
+
+ self._texture.discard() # Also discard the texture
+ else:
+ _logger.warning("Discard has already been called")
diff --git a/silx/gui/_glutils/Program.py b/silx/gui/_glutils/Program.py
new file mode 100644
index 0000000..48c12f5
--- /dev/null
+++ b/silx/gui/_glutils/Program.py
@@ -0,0 +1,202 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a class to handle shader program compilation."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import logging
+
+import numpy
+
+from . import gl
+from .Context import getGLContext
+
+_logger = logging.getLogger(__name__)
+
+
+class Program(object):
+ """Wrap OpenGL shader program.
+
+ The program is compiled lazily (i.e., at first program :meth:`use`).
+ When the program is compiled, it stores attributes and uniforms locations.
+ So, attributes and uniforms must be used after :meth:`use`.
+
+ This object supports multiple OpenGL contexts.
+
+ :param str vertexShader: The source of the vertex shader.
+ :param str fragmentShader: The source of the fragment shader.
+ :param str attrib0:
+ Attribute's name to bind to position 0 (default: 'position').
+ On certain platform, this attribute MUST be active and with an
+ array attached to it in order for the rendering to occur....
+ """
+
+ def __init__(self, vertexShader, fragmentShader,
+ attrib0='position'):
+ self._vertexShader = vertexShader
+ self._fragmentShader = fragmentShader
+ self._attrib0 = attrib0
+ self._programs = {}
+
+ @staticmethod
+ def _compileGL(vertexShader, fragmentShader, attrib0):
+ program = gl.glCreateProgram()
+
+ gl.glBindAttribLocation(program, 0, attrib0.encode('ascii'))
+
+ vertex = gl.glCreateShader(gl.GL_VERTEX_SHADER)
+ gl.glShaderSource(vertex, vertexShader)
+ gl.glCompileShader(vertex)
+ if gl.glGetShaderiv(vertex, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
+ raise RuntimeError(gl.glGetShaderInfoLog(vertex))
+ gl.glAttachShader(program, vertex)
+ gl.glDeleteShader(vertex)
+
+ fragment = gl.glCreateShader(gl.GL_FRAGMENT_SHADER)
+ gl.glShaderSource(fragment, fragmentShader)
+ gl.glCompileShader(fragment)
+ if gl.glGetShaderiv(fragment,
+ gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
+ raise RuntimeError(gl.glGetShaderInfoLog(fragment))
+ gl.glAttachShader(program, fragment)
+ gl.glDeleteShader(fragment)
+
+ gl.glLinkProgram(program)
+ if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
+ raise RuntimeError(gl.glGetProgramInfoLog(program))
+
+ attributes = {}
+ for index in range(gl.glGetProgramiv(program,
+ gl.GL_ACTIVE_ATTRIBUTES)):
+ name = gl.glGetActiveAttrib(program, index)[0]
+ namestr = name.decode('ascii')
+ attributes[namestr] = gl.glGetAttribLocation(program, name)
+
+ uniforms = {}
+ for index in range(gl.glGetProgramiv(program, gl.GL_ACTIVE_UNIFORMS)):
+ name = gl.glGetActiveUniform(program, index)[0]
+ namestr = name.decode('ascii')
+ uniforms[namestr] = gl.glGetUniformLocation(program, name)
+
+ return program, attributes, uniforms
+
+ def _getProgramInfo(self):
+ glcontext = getGLContext()
+ if glcontext not in self._programs:
+ raise RuntimeError(
+ "Program was not compiled for current OpenGL context.")
+ return self._programs[glcontext]
+
+ @property
+ def attributes(self):
+ """Vertex attributes names and locations as a dict of {str: int}.
+
+ WARNING:
+ Read-only usage.
+ To use only with a valid OpenGL context and after :meth:`use`
+ has been called for this context.
+ """
+ return self._getProgramInfo()[1]
+
+ @property
+ def uniforms(self):
+ """Program uniforms names and locations as a dict of {str: int}.
+
+ WARNING:
+ Read-only usage.
+ To use only with a valid OpenGL context and after :meth:`use`
+ has been called for this context.
+ """
+ return self._getProgramInfo()[2]
+
+ @property
+ def program(self):
+ """OpenGL id of the program.
+
+ WARNING:
+ To use only with a valid OpenGL context and after :meth:`use`
+ has been called for this context.
+ """
+ return self._getProgramInfo()[0]
+
+ # def discard(self):
+ # pass # Not implemented yet
+
+ def use(self):
+ """Make use of the program, compiling it if necessary"""
+ glcontext = getGLContext()
+
+ if glcontext not in self._programs:
+ self._programs[glcontext] = self._compileGL(
+ self._vertexShader,
+ self._fragmentShader,
+ self._attrib0)
+
+ if _logger.getEffectiveLevel() <= logging.DEBUG:
+ gl.glValidateProgram(self.program)
+ if gl.glGetProgramiv(
+ self.program, gl.GL_VALIDATE_STATUS) != gl.GL_TRUE:
+ _logger.debug('Cannot validate program: %s',
+ gl.glGetProgramInfoLog(self.program))
+
+ gl.glUseProgram(self.program)
+
+ def setUniformMatrix(self, name, value, transpose=True, safe=False):
+ """Wrap glUniformMatrix[2|3|4]fv
+
+ :param str name: The name of the uniform.
+ :param value: The 2D matrix (or the array of matrices, 3D).
+ Matrices are 2x2, 3x3 or 4x4.
+ :type value: numpy.ndarray with 2 or 3 dimensions of float32
+ :param bool transpose: Whether to transpose (True, default) or not.
+ :param bool safe: False: raise an error if no uniform with this name;
+ True: silently ignores it.
+
+ :raises KeyError: if no uniform corresponds to name.
+ """
+ assert value.dtype == numpy.float32
+
+ shape = value.shape
+ assert len(shape) in (2, 3)
+ assert shape[-1] in (2, 3, 4)
+ assert shape[-1] == shape[-2] # As in OpenGL|ES 2.0
+
+ location = self.uniforms.get(name)
+ if location is not None:
+ count = 1 if len(shape) == 2 else shape[0]
+ transpose = gl.GL_TRUE if transpose else gl.GL_FALSE
+
+ if shape[-1] == 2:
+ gl.glUniformMatrix2fv(location, count, transpose, value)
+ elif shape[-1] == 3:
+ gl.glUniformMatrix3fv(location, count, transpose, value)
+ elif shape[-1] == 4:
+ gl.glUniformMatrix4fv(location, count, transpose, value)
+
+ elif not safe:
+ raise KeyError('No uniform: %s' % name)
diff --git a/silx/gui/_glutils/Texture.py b/silx/gui/_glutils/Texture.py
new file mode 100644
index 0000000..9f09a86
--- /dev/null
+++ b/silx/gui/_glutils/Texture.py
@@ -0,0 +1,308 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a class wrapping OpenGL 2D and 3D texture."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "04/10/2016"
+
+
+import collections
+from ctypes import c_void_p
+import logging
+
+import numpy
+
+from . import gl, utils
+
+
+_logger = logging.getLogger(__name__)
+
+
+class Texture(object):
+ """Base class to wrap OpenGL 2D and 3D texture
+
+ :param internalFormat: OpenGL texture internal format
+ :param data: The data to copy to the texture or None for an empty texture
+ :type data: numpy.ndarray or None
+ :param format_: Input data format if different from internalFormat
+ :param shape: If data is None, shape of the texture
+ :type shape: 2 or 3-tuple of int (height, width) or (depth, height, width)
+ :param int texUnit: The texture unit to use
+ :param minFilter: OpenGL texture minimization filter (default: GL_NEAREST)
+ :param magFilter: OpenGL texture magnification filter (default: GL_LINEAR)
+ :param wrap: Texture wrap mode for dimensions: (t, s) or (r, t, s)
+ If a single value is provided, it used for all dimensions.
+ :type wrap: OpenGL wrap mode or 2 or 3-tuple of wrap mode
+ """
+
+ def __init__(self, internalFormat, data=None, format_=None,
+ shape=None, texUnit=0,
+ minFilter=None, magFilter=None, wrap=None):
+
+ self._internalFormat = internalFormat
+ if format_ is None:
+ format_ = self.internalFormat
+
+ if data is None:
+ assert shape is not None
+ else:
+ assert shape is None
+ data = numpy.array(data, copy=False, order='C')
+ if format_ != gl.GL_RED:
+ shape = data.shape[:-1] # Last dimension is channels
+ else:
+ shape = data.shape
+
+ assert len(shape) in (2, 3)
+ self._shape = tuple(shape)
+ self._ndim = len(shape)
+
+ self.texUnit = texUnit
+
+ self._name = gl.glGenTextures(1)
+ self.bind(self.texUnit)
+
+ self._minFilter = None
+ self.minFilter = minFilter if minFilter is not None else gl.GL_NEAREST
+
+ self._magFilter = None
+ self.magFilter = magFilter if magFilter is not None else gl.GL_LINEAR
+
+ if wrap is not None:
+ if not isinstance(wrap, collections.Iterable):
+ wrap = [wrap] * self.ndim
+
+ assert len(wrap) == self.ndim
+
+ gl.glTexParameter(self.target,
+ gl.GL_TEXTURE_WRAP_S,
+ wrap[-1])
+ gl.glTexParameter(self.target,
+ gl.GL_TEXTURE_WRAP_T,
+ wrap[-2])
+ if self.ndim == 3:
+ gl.glTexParameter(self.target,
+ gl.GL_TEXTURE_WRAP_R,
+ wrap[0])
+
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
+
+ # This are the defaults, useless to set if not modified
+ # gl.glPixelStorei(gl.GL_UNPACK_ROW_LENGTH, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_SKIP_PIXELS, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_SKIP_ROWS, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_IMAGE_HEIGHT, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_SKIP_IMAGES, 0)
+
+ if data is None:
+ data = c_void_p(0)
+ type_ = gl.GL_UNSIGNED_BYTE
+ else:
+ type_ = utils.numpyToGLType(data.dtype)
+
+ if self.ndim == 2:
+ _logger.debug(
+ 'Creating 2D texture shape: (%d, %d),'
+ ' internal format: %s, format: %s, type: %s',
+ self.shape[0], self.shape[1],
+ str(self.internalFormat), str(format_), str(type_))
+
+ gl.glTexImage2D(
+ gl.GL_TEXTURE_2D,
+ 0,
+ self.internalFormat,
+ self.shape[1],
+ self.shape[0],
+ 0,
+ format_,
+ type_,
+ data)
+ else:
+ _logger.debug(
+ 'Creating 3D texture shape: (%d, %d, %d),'
+ ' internal format: %s, format: %s, type: %s',
+ self.shape[0], self.shape[1], self.shape[2],
+ str(self.internalFormat), str(format_), str(type_))
+
+ gl.glTexImage3D(
+ gl.GL_TEXTURE_3D,
+ 0,
+ self.internalFormat,
+ self.shape[2],
+ self.shape[1],
+ self.shape[0],
+ 0,
+ format_,
+ type_,
+ data)
+
+ gl.glBindTexture(self.target, 0)
+
+ @property
+ def target(self):
+ """OpenGL target type of this texture"""
+ return gl.GL_TEXTURE_2D if self.ndim == 2 else gl.GL_TEXTURE_3D
+
+ @property
+ def ndim(self):
+ """The number of dimensions: 2 or 3"""
+ return self._ndim
+
+ @property
+ def internalFormat(self):
+ """Texture internal format"""
+ return self._internalFormat
+
+ @property
+ def shape(self):
+ """Shape of the texture: (height, width) or (depth, height, width)"""
+ return self._shape
+
+ @property
+ def name(self):
+ """OpenGL texture name"""
+ if self._name is not None:
+ return self._name
+ else:
+ raise RuntimeError(
+ "No OpenGL texture resource, discard has already been called")
+
+ @property
+ def minFilter(self):
+ """Minifying function parameter (GL_TEXTURE_MIN_FILTER)"""
+ return self._minFilter
+
+ @minFilter.setter
+ def minFilter(self, minFilter):
+ if minFilter != self.minFilter:
+ self._minFilter = minFilter
+ self.bind()
+ gl.glTexParameter(self.target,
+ gl.GL_TEXTURE_MIN_FILTER,
+ self.minFilter)
+
+ @property
+ def magFilter(self):
+ """Magnification function parameter (GL_TEXTURE_MAG_FILTER)"""
+ return self._magFilter
+
+ @magFilter.setter
+ def magFilter(self, magFilter):
+ if magFilter != self.magFilter:
+ self._magFilter = magFilter
+ self.bind()
+ gl.glTexParameter(self.target,
+ gl.GL_TEXTURE_MAG_FILTER,
+ self.magFilter)
+
+ def discard(self):
+ """Delete associated OpenGL texture"""
+ if self._name is not None:
+ gl.glDeleteTextures(self._name)
+ self._name = None
+ else:
+ _logger.warning("Discard as already been called")
+
+ def bind(self, texUnit=None):
+ """Bind the texture to a texture unit.
+
+ :param int texUnit: The texture unit to use
+ """
+ if texUnit is None:
+ texUnit = self.texUnit
+ gl.glActiveTexture(gl.GL_TEXTURE0 + texUnit)
+ gl.glBindTexture(self.target, self.name)
+
+ # with statement
+
+ def __enter__(self):
+ self.bind()
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ gl.glActiveTexture(gl.GL_TEXTURE0 + self.texUnit)
+ gl.glBindTexture(self.target, 0)
+
+ def update(self,
+ format_,
+ data,
+ offset=(0, 0, 0),
+ texUnit=None):
+ """Update the content of the texture.
+
+ Texture is not resized, so data must fit into texture with the
+ given offset.
+
+ :param format_: The OpenGL format of the data
+ :param data: The data to use to update the texture
+ :param offset: The offset in the texture where to copy the data
+ :type offset: 2 or 3-tuple of int
+ :param int texUnit:
+ The texture unit to use (default: the one provided at init)
+ """
+ data = numpy.array(data, copy=False, order='C')
+
+ assert data.ndim == self.ndim
+ assert len(offset) >= self.ndim
+ for i in range(self.ndim):
+ assert offset[i] + data.shape[i] <= self.shape[i]
+
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
+
+ # This are the defaults, useless to set if not modified
+ # gl.glPixelStorei(gl.GL_UNPACK_ROW_LENGTH, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_SKIP_PIXELS, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_SKIP_ROWS, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_IMAGE_HEIGHT, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_SKIP_IMAGES, 0)
+
+ self.bind(texUnit)
+
+ type_ = utils.numpyToGLType(data.dtype)
+
+ if self.ndim == 2:
+ gl.glTexSubImage2D(gl.GL_TEXTURE_2D,
+ 0,
+ offset[1],
+ offset[0],
+ data.shape[1],
+ data.shape[0],
+ format_,
+ type_,
+ data)
+ gl.glBindTexture(gl.GL_TEXTURE_2D, 0)
+ else:
+ gl.glTexSubImage3D(gl.GL_TEXTURE_3D,
+ 0,
+ offset[2],
+ offset[1],
+ offset[0],
+ data.shape[2],
+ data.shape[1],
+ data.shape[0],
+ format_,
+ type_,
+ data)
+ gl.glBindTexture(gl.GL_TEXTURE_3D, 0)
diff --git a/silx/gui/_glutils/VertexBuffer.py b/silx/gui/_glutils/VertexBuffer.py
new file mode 100644
index 0000000..689b543
--- /dev/null
+++ b/silx/gui/_glutils/VertexBuffer.py
@@ -0,0 +1,266 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a class managing an OpenGL vertex buffer."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "10/01/2017"
+
+
+import logging
+from ctypes import c_void_p
+import numpy
+
+from . import gl
+from .utils import numpyToGLType, sizeofGLType
+
+
+_logger = logging.getLogger(__name__)
+
+
+class VertexBuffer(object):
+ """Object handling an OpenGL vertex buffer object
+
+ :param data: Data used to fill the vertex buffer
+ :type data: numpy.ndarray or None
+ :param int size: Size in bytes of the buffer or None for data size
+ :param usage: OpenGL vertex buffer expected usage pattern:
+ GL_STREAM_DRAW, GL_STATIC_DRAW (default) or GL_DYNAMIC_DRAW
+ :param target: Target buffer:
+ GL_ARRAY_BUFFER (default) or GL_ELEMENT_ARRAY_BUFFER
+ """
+ # OpenGL|ES 2.0 subset:
+ _USAGES = gl.GL_STREAM_DRAW, gl.GL_STATIC_DRAW, gl.GL_DYNAMIC_DRAW
+ _TARGETS = gl.GL_ARRAY_BUFFER, gl.GL_ELEMENT_ARRAY_BUFFER
+
+ def __init__(self,
+ data=None,
+ size=None,
+ usage=None,
+ target=None):
+ if usage is None:
+ usage = gl.GL_STATIC_DRAW
+ assert usage in self._USAGES
+
+ if target is None:
+ target = gl.GL_ARRAY_BUFFER
+ assert target in self._TARGETS
+
+ self._target = target
+ self._usage = usage
+
+ self._name = gl.glGenBuffers(1)
+ self.bind()
+
+ if data is None:
+ assert size is not None
+ self._size = size
+ gl.glBufferData(self._target,
+ self._size,
+ c_void_p(0),
+ self._usage)
+ else:
+ data = numpy.array(data, copy=False, order='C')
+ if size is not None:
+ assert size <= data.nbytes
+
+ self._size = size or data.nbytes
+ gl.glBufferData(self._target,
+ self._size,
+ data,
+ self._usage)
+
+ gl.glBindBuffer(self._target, 0)
+
+ @property
+ def target(self):
+ """The target buffer of the vertex buffer"""
+ return self._target
+
+ @property
+ def usage(self):
+ """The expected usage of the vertex buffer"""
+ return self._usage
+
+ @property
+ def name(self):
+ """OpenGL Vertex Buffer object name (int)"""
+ if self._name is not None:
+ return self._name
+ else:
+ raise RuntimeError("No OpenGL buffer resource, \
+ discard has already been called")
+
+ @property
+ def size(self):
+ """Size in bytes of the Vertex Buffer Object (int)"""
+ if self._size is not None:
+ return self._size
+ else:
+ raise RuntimeError("No OpenGL buffer resource, \
+ discard has already been called")
+
+ def bind(self):
+ """Bind the vertex buffer"""
+ gl.glBindBuffer(self._target, self.name)
+
+ def update(self, data, offset=0, size=None):
+ """Update vertex buffer content.
+
+ :param numpy.ndarray data: The data to put in the vertex buffer
+ :param int offset: Offset in bytes in the buffer where to put the data
+ :param int size: If provided, size of data to copy
+ """
+ data = numpy.array(data, copy=False, order='C')
+ if size is None:
+ size = data.nbytes
+ assert offset + size <= self.size
+ with self:
+ gl.glBufferSubData(self._target, offset, size, data)
+
+ def discard(self):
+ """Delete the vertex buffer"""
+ if self._name is not None:
+ gl.glDeleteBuffers(self._name)
+ self._name = None
+ self._size = None
+ else:
+ _logger.warning("Discard has already been called")
+
+ # with statement
+
+ def __enter__(self):
+ self.bind()
+
+ def __exit__(self, exctype, excvalue, traceback):
+ gl.glBindBuffer(self._target, 0)
+
+
+class VertexBufferAttrib(object):
+ """Describes data stored in a vertex buffer
+
+ Convenient class to store info for glVertexAttribPointer calls
+
+ :param VertexBuffer vbo: The vertex buffer storing the data
+ :param int type_: The OpenGL type of the data
+ :param int size: The number of data elements stored in the VBO
+ :param int dimension: The number of `type_` element(s) in [1, 4]
+ :param int offset: Start offset of data in the vertex buffer
+ :param int stride: Data stride in the vertex buffer
+ """
+
+ _GL_TYPES = gl.GL_UNSIGNED_BYTE, gl.GL_FLOAT, gl.GL_INT
+
+ def __init__(self,
+ vbo,
+ type_,
+ size,
+ dimension=1,
+ offset=0,
+ stride=0,
+ normalisation=False):
+ self.vbo = vbo
+ assert type_ in self._GL_TYPES
+ self.type_ = type_
+ self.size = size
+ assert 1 <= dimension <= 4
+ self.dimension = dimension
+ self.offset = offset
+ self.stride = stride
+ self.normalisation = bool(normalisation)
+
+ @property
+ def itemsize(self):
+ """Size in bytes of a vertex buffer element (int)"""
+ return self.dimension * sizeofGLType(self.type_)
+
+ itemSize = itemsize # Backward compatibility
+
+ def setVertexAttrib(self, attribute):
+ """Call glVertexAttribPointer with objects information"""
+ normalisation = gl.GL_TRUE if self.normalisation else gl.GL_FALSE
+ with self.vbo:
+ gl.glVertexAttribPointer(attribute,
+ self.dimension,
+ self.type_,
+ normalisation,
+ self.stride,
+ c_void_p(self.offset))
+
+ def copy(self):
+ return VertexBufferAttrib(self.vbo,
+ self.type_,
+ self.size,
+ self.dimension,
+ self.offset,
+ self.stride,
+ self.normalisation)
+
+
+def vertexBuffer(arrays, prefix=None, suffix=None, usage=None):
+ """Create a single vertex buffer from multiple 1D or 2D numpy arrays.
+
+ It is possible to reserve memory before and after each array in the VBO
+
+ :param arrays: Arrays of data to store
+ :type arrays: Iterable of numpy.ndarray
+ :param prefix: If given, number of elements to reserve before each array
+ :type prefix: Iterable of int or None
+ :param suffix: If given, number of elements to reserve after each array
+ :type suffix: Iterable of int or None
+ :param int usage: vertex buffer expected usage or None for default
+ :returns: List of VertexBufferAttrib objects sharing the same vertex buffer
+ """
+ info = []
+ vbosize = 0
+
+ if prefix is None:
+ prefix = (0,) * len(arrays)
+ if suffix is None:
+ suffix = (0,) * len(arrays)
+
+ for data, pre, post in zip(arrays, prefix, suffix):
+ data = numpy.array(data, copy=False, order='C')
+ shape = data.shape
+ assert len(shape) <= 2
+ type_ = numpyToGLType(data.dtype)
+ size = shape[0] + pre + post
+ dimension = 1 if len(shape) == 1 else shape[1]
+ sizeinbytes = size * dimension * sizeofGLType(type_)
+ sizeinbytes = 4 * ((sizeinbytes + 3) >> 2) # 4 bytes alignment
+ copyoffset = vbosize + pre * dimension * sizeofGLType(type_)
+ info.append((data, type_, size, dimension,
+ vbosize, sizeinbytes, copyoffset))
+ vbosize += sizeinbytes
+
+ vbo = VertexBuffer(size=vbosize, usage=usage)
+
+ result = []
+ for data, type_, size, dimension, offset, sizeinbytes, copyoffset in info:
+ copysize = data.shape[0] * dimension * sizeofGLType(type_)
+ vbo.update(data, offset=copyoffset, size=copysize)
+ result.append(
+ VertexBufferAttrib(vbo, type_, size, dimension, offset, 0))
+ return result
diff --git a/silx/gui/_glutils/__init__.py b/silx/gui/_glutils/__init__.py
new file mode 100644
index 0000000..e86a58f
--- /dev/null
+++ b/silx/gui/_glutils/__init__.py
@@ -0,0 +1,41 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides utility functions to handle OpenGL resources.
+
+The :mod:`gl` module provides a wrapper to OpenGL based on PyOpenGL.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+# OpenGL convenient functions
+from .Context import getGLContext, setGLContextGetter # noqa
+from .FramebufferTexture import FramebufferTexture # noqa
+from .Program import Program # noqa
+from .Texture import Texture # noqa
+from .VertexBuffer import VertexBuffer, VertexBufferAttrib, vertexBuffer # noqa
+from .utils import sizeofGLType, isSupportedGLType, numpyToGLType # noqa
diff --git a/silx/gui/_glutils/font.py b/silx/gui/_glutils/font.py
new file mode 100644
index 0000000..566ae49
--- /dev/null
+++ b/silx/gui/_glutils/font.py
@@ -0,0 +1,152 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Text rasterisation feature leveraging Qt font and text layout support."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "13/10/2016"
+
+
+import logging
+import sys
+import numpy
+from .. import qt
+from .._utils import convertQImageToArray
+
+
+_logger = logging.getLogger(__name__)
+
+
+def getDefaultFontFamily():
+ """Returns the default font family of the application"""
+ return qt.QApplication.instance().font().family()
+
+
+# Font weights
+ULTRA_LIGHT = 0
+"""Lightest characters: Minimum font weight"""
+
+LIGHT = 25
+"""Light characters"""
+
+NORMAL = 50
+"""Normal characters"""
+
+SEMI_BOLD = 63
+"""Between normal and bold characters"""
+
+BOLD = 74
+"""Thicker characters"""
+
+BLACK = 87
+"""Really thick characters"""
+
+ULTRA_BLACK = 99
+"""Thickest characters: Maximum font weight"""
+
+
+def rasterText(text, font,
+ size=-1,
+ weight=-1,
+ italic=False,
+ devicePixelRatio=1.0):
+ """Raster text using Qt.
+
+ It supports multiple lines.
+
+ :param str text: The text to raster
+ :param font: Font name or QFont to use
+ :type font: str or :class:`QFont`
+ :param int size:
+ Font size in points
+ Used only if font is given as name.
+ :param int weight:
+ Font weight in [0, 99], see QFont.Weight.
+ Used only if font is given as name.
+ :param bool italic:
+ True for italic font (default: False).
+ Used only if font is given as name.
+ :param float devicePixelRatio:
+ The current ratio between device and device-independent pixel
+ (default: 1.0)
+ :return: Corresponding image in gray scale and baseline offset from top
+ :rtype: (HxW numpy.ndarray of uint8, int)
+ """
+ if not text:
+ _logger.info("Trying to raster empty text, replaced by white space")
+ text = ' ' # Replace empty text by white space to produce an image
+
+ if not isinstance(font, qt.QFont):
+ font = qt.QFont(font, size, weight, italic)
+
+ metrics = qt.QFontMetrics(font)
+ size = metrics.size(qt.Qt.TextExpandTabs, text)
+ bounds = metrics.boundingRect(
+ qt.QRect(0, 0, size.width(), size.height()),
+ qt.Qt.TextExpandTabs,
+ text)
+
+ if (devicePixelRatio != 1.0 and
+ not hasattr(qt.QImage, 'setDevicePixelRatio')): # Qt 4
+ _logger.error('devicePixelRatio not supported')
+ devicePixelRatio = 1.0
+
+ # Add extra border and handle devicePixelRatio
+ width = bounds.width() * devicePixelRatio + 2
+ # align line size to 32 bits to ease conversion to numpy array
+ width = 4 * ((width + 3) // 4)
+ image = qt.QImage(width,
+ bounds.height() * devicePixelRatio,
+ qt.QImage.Format_RGB888)
+ if (devicePixelRatio != 1.0 and
+ hasattr(image, 'setDevicePixelRatio')): # Qt 5
+ image.setDevicePixelRatio(devicePixelRatio)
+
+ # TODO if Qt5 use Format_Grayscale8 instead
+ image.fill(0)
+
+ # Raster text
+ painter = qt.QPainter()
+ painter.begin(image)
+ painter.setPen(qt.Qt.white)
+ painter.setFont(font)
+ painter.drawText(bounds, qt.Qt.TextExpandTabs, text)
+ painter.end()
+
+ array = convertQImageToArray(image)
+
+ # RGB to R
+ array = numpy.ascontiguousarray(array[:, :, 0])
+
+ # Remove leading and trailing empty columns but one on each side
+ column_cumsum = numpy.cumsum(numpy.sum(array, axis=0))
+ array = array[:, column_cumsum.argmin():column_cumsum.argmax() + 2]
+
+ # Remove leading and trailing empty rows but one on each side
+ row_cumsum = numpy.cumsum(numpy.sum(array, axis=1))
+ min_row = row_cumsum.argmin()
+ array = array[min_row:row_cumsum.argmax() + 2, :]
+
+ return array, metrics.ascent() - min_row
diff --git a/silx/gui/_glutils/gl.py b/silx/gui/_glutils/gl.py
new file mode 100644
index 0000000..4b9a7bb
--- /dev/null
+++ b/silx/gui/_glutils/gl.py
@@ -0,0 +1,165 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module loads PyOpenGL and provides a namespace for OpenGL."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+from contextlib import contextmanager as _contextmanager
+from ctypes import c_uint
+import logging
+
+_logger = logging.getLogger(__name__)
+
+import OpenGL
+# Set the following to true for debugging
+if _logger.getEffectiveLevel() <= logging.DEBUG:
+ _logger.debug('Enabling PyOpenGL debug flags')
+ OpenGL.ERROR_LOGGING = True
+ OpenGL.ERROR_CHECKING = True
+ OpenGL.ERROR_ON_COPY = True
+else:
+ OpenGL.ERROR_LOGGING = False
+ OpenGL.ERROR_CHECKING = False
+ OpenGL.ERROR_ON_COPY = False
+
+import OpenGL.GL as _GL
+from OpenGL.GL import * # noqa
+
+# Extentions core in OpenGL 3
+from OpenGL.GL.ARB import framebuffer_object as _FBO
+from OpenGL.GL.ARB.framebuffer_object import * # noqa
+from OpenGL.GL.ARB.texture_rg import GL_R32F, GL_R16F # noqa
+from OpenGL.GL.ARB.texture_rg import GL_R16, GL_R8 # noqa
+
+# PyOpenGL 3.0.1 does not define it
+try:
+ GLchar
+except NameError:
+ from ctypes import c_char
+ GLchar = c_char
+
+
+def testGL():
+ """Test if required OpenGL version and extensions are available.
+
+ This MUST be run with an active OpenGL context.
+ """
+ version = glGetString(GL_VERSION).split()[0] # get version number
+ major, minor = int(version[0]), int(version[2])
+ if major < 2 or (major == 2 and minor < 1):
+ raise RuntimeError(
+ "Requires at least OpenGL version 2.1, running with %s" % version)
+
+ from OpenGL.GL.ARB.framebuffer_object import glInitFramebufferObjectARB
+ from OpenGL.GL.ARB.texture_rg import glInitTextureRgARB
+
+ if not glInitFramebufferObjectARB():
+ raise RuntimeError(
+ "OpenGL GL_ARB_framebuffer_object extension required !")
+
+ if not glInitTextureRgARB():
+ raise RuntimeError("OpenGL GL_ARB_texture_rg extension required !")
+
+
+# Additional setup
+if hasattr(glget, 'addGLGetConstant'):
+ glget.addGLGetConstant(GL_FRAMEBUFFER_BINDING, (1,))
+
+
+@_contextmanager
+def enabled(capacity, enable=True):
+ """Context manager enabling an OpenGL capacity.
+
+ This is not checking the current state of the capacity.
+
+ :param capacity: The OpenGL capacity enum to enable/disable
+ :param bool enable:
+ True (default) to enable during context, False to disable
+ """
+ if enable:
+ glEnable(capacity)
+ yield
+ glDisable(capacity)
+ else:
+ glDisable(capacity)
+ yield
+ glEnable(capacity)
+
+
+def disabled(capacity, disable=True):
+ """Context manager disabling an OpenGL capacity.
+
+ This is not checking the current state of the capacity.
+
+ :param capacity: The OpenGL capacity enum to disable/enable
+ :param bool disable:
+ True (default) to disable during context, False to enable
+ """
+ return enabled(capacity, not disable)
+
+
+# Additional OpenGL wrapping
+
+def glGetActiveAttrib(program, index):
+ """Wrap PyOpenGL glGetActiveAttrib"""
+ bufsize = glGetProgramiv(program, GL_ACTIVE_ATTRIBUTE_MAX_LENGTH)
+ length = GLsizei()
+ size = GLint()
+ type_ = GLenum()
+ name = (GLchar * bufsize)()
+
+ _GL.glGetActiveAttrib(program, index, bufsize, length, size, type_, name)
+ return name.value, size.value, type_.value
+
+
+def glDeleteRenderbuffers(buffers):
+ if not hasattr(buffers, '__len__'): # Support single int argument
+ buffers = [buffers]
+ length = len(buffers)
+ _FBO.glDeleteRenderbuffers(length, (c_uint * length)(*buffers))
+
+
+def glDeleteFramebuffers(buffers):
+ if not hasattr(buffers, '__len__'): # Support single int argument
+ buffers = [buffers]
+ length = len(buffers)
+ _FBO.glDeleteFramebuffers(length, (c_uint * length)(*buffers))
+
+
+def glDeleteBuffers(buffers):
+ if not hasattr(buffers, '__len__'): # Support single int argument
+ buffers = [buffers]
+ length = len(buffers)
+ _GL.glDeleteBuffers(length, (c_uint * length)(*buffers))
+
+
+def glDeleteTextures(textures):
+ if not hasattr(textures, '__len__'): # Support single int argument
+ textures = [textures]
+ length = len(textures)
+ _GL.glDeleteTextures((c_uint * length)(*textures))
diff --git a/silx/gui/_glutils/utils.py b/silx/gui/_glutils/utils.py
new file mode 100644
index 0000000..73af338
--- /dev/null
+++ b/silx/gui/_glutils/utils.py
@@ -0,0 +1,70 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides conversion functions between OpenGL and numpy types.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "10/01/2017"
+
+from . import gl
+import numpy
+
+
+_GL_TYPE_SIZES = {
+ gl.GL_FLOAT: 4,
+ gl.GL_BYTE: 1,
+ gl.GL_SHORT: 2,
+ gl.GL_INT: 4,
+ gl.GL_UNSIGNED_BYTE: 1,
+ gl.GL_UNSIGNED_SHORT: 2,
+ gl.GL_UNSIGNED_INT: 4,
+}
+
+
+def sizeofGLType(type_):
+ """Returns the size in bytes of an element of type `type_`"""
+ return _GL_TYPE_SIZES[type_]
+
+
+_TYPE_CONVERTER = {
+ numpy.dtype(numpy.float32): gl.GL_FLOAT,
+ numpy.dtype(numpy.int8): gl.GL_BYTE,
+ numpy.dtype(numpy.int16): gl.GL_SHORT,
+ numpy.dtype(numpy.int32): gl.GL_INT,
+ numpy.dtype(numpy.uint8): gl.GL_UNSIGNED_BYTE,
+ numpy.dtype(numpy.uint16): gl.GL_UNSIGNED_SHORT,
+ numpy.dtype(numpy.uint32): gl.GL_UNSIGNED_INT,
+}
+
+
+def isSupportedGLType(type_):
+ """Test if a numpy type or dtype can be converted to a GL type."""
+ return numpy.dtype(type_) in _TYPE_CONVERTER
+
+
+def numpyToGLType(type_):
+ """Returns the GL type corresponding the provided numpy type or dtype."""
+ return _TYPE_CONVERTER[numpy.dtype(type_)]
diff --git a/silx/gui/_utils.py b/silx/gui/_utils.py
new file mode 100644
index 0000000..e29141f
--- /dev/null
+++ b/silx/gui/_utils.py
@@ -0,0 +1,102 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides convenient functions to use with Qt objects.
+
+It provides conversion between numpy and QImage.
+"""
+
+from __future__ import division
+
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+
+import sys
+import numpy
+
+from . import qt
+
+
+def convertArrayToQImage(image):
+ """Convert an array-like RGB888 image to a QImage.
+
+ The created QImage is using a copy of the array data.
+
+ Limitation: Only supports RGB888 format.
+
+ :param image: Array-like image data
+ :type image: numpy.ndarray of uint8 of dimension HxWx3
+ :return: Corresponding Qt image
+ :rtype: QImage
+ """
+ # Possible extension: add a format argument to support more formats
+
+ image = numpy.array(image, copy=False, order='C', dtype=numpy.uint8)
+
+ height, width, depth = image.shape
+ assert depth == 3
+
+ qimage = qt.QImage(
+ image.data,
+ width,
+ height,
+ image.strides[0], # bytesPerLine
+ qt.QImage.Format_RGB888)
+
+ return qimage.copy() # Making a copy of the image and its data
+
+
+def convertQImageToArray(image):
+ """Convert a RGB888 QImage to a numpy array.
+
+ Limitation: Only supports RGB888 format.
+ If QImage is not RGB888 it gets converted to this format.
+
+ :param QImage: The QImage to convert.
+ :return: The image array
+ :rtype: numpy.ndarray of uint8 of shape HxWx3
+ """
+ # Possible extension: avoid conversion to support more formats
+
+ if image.format() != qt.QImage.Format_RGB888:
+ # Convert to RGB888 if needed
+ image = image.convertToFormat(qt.QImage.Format_RGB888)
+
+ ptr = image.bits()
+ if qt.BINDING != 'PySide':
+ ptr.setsize(image.byteCount())
+ if qt.BINDING == 'PyQt4' and sys.version_info[0] == 2:
+ ptr = ptr.asstring()
+ elif sys.version_info[0] == 3: # PySide with Python3
+ ptr = ptr.tobytes()
+
+ array = numpy.fromstring(ptr, dtype=numpy.uint8)
+
+ # Lines are 32 bits aligned: remove padding bytes
+ array = array.reshape(image.height(), -1)[:, :image.width() * 3]
+ array.shape = image.height(), image.width(), 3
+ return array
diff --git a/silx/gui/console.py b/silx/gui/console.py
new file mode 100644
index 0000000..13760b4
--- /dev/null
+++ b/silx/gui/console.py
@@ -0,0 +1,214 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides an IPython console widget.
+
+You can push variables - any python object - to the
+console's interactive namespace. This provides users with an advanced way
+of interacting with your program. For instance, if your program has a
+:class:`PlotWidget` or a :class:`PlotWindow`, you can push a reference to
+these widgets to allow your users to add curves, save data to files… by using
+the widgets' methods from the console.
+
+.. note::
+
+ This module has a dependency on
+ `IPython <https://pypi.python.org/pypi/ipython>`_ and
+ `qtconsole <https://pypi.python.org/pypi/qtconsole>`_ (or *ipython.qt* for
+ older versions of *IPython*). An ``ImportError`` will be raised if it is
+ imported while the dependencies are not satisfied.
+
+Basic usage example::
+
+ from silx.gui import qt
+ from silx.gui.console import IPythonWidget
+
+ app = qt.QApplication([])
+
+ hello_button = qt.QPushButton("Hello World!", None)
+ hello_button.show()
+
+ console = IPythonWidget()
+ console.show()
+ console.pushVariables({"the_button": hello_button})
+
+ app.exec_()
+
+This program will display a console widget and a push button in two separate
+windows. You will be able to interact with the button from the console,
+for example change its text::
+
+ >>> the_button.setText("Spam spam")
+
+An IPython interactive console is a powerful tool that enables you to work
+with data and plot it.
+See `this tutorial <https://plot.ly/python/ipython-notebook-tutorial/>`_
+for more information on some of the rich features of IPython.
+"""
+__authors__ = ["Tim Rae", "V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/05/2016"
+
+import logging
+
+from . import qt
+
+_logger = logging.getLogger(__name__)
+
+try:
+ import IPython
+except ImportError as e:
+ raise ImportError("Failed to import IPython, required by " + __name__)
+
+# This widget cannot be used inside an interactive IPython shell.
+# It would raise MultipleInstanceError("Multiple incompatible subclass
+# instances of InProcessInteractiveShell are being created").
+try:
+ __IPYTHON__
+except NameError:
+ pass # Not in IPython
+else:
+ msg = "Module " + __name__ + " cannot be used within an IPython shell"
+ raise ImportError(msg)
+
+# qtconsole is a separate module in recent versions of IPython/Jupyter
+# http://blog.jupyter.org/2015/04/15/the-big-split/
+if IPython.__version__.startswith("2"):
+ qtconsole = None
+else:
+ try:
+ import qtconsole
+ except ImportError:
+ qtconsole = None
+
+if qtconsole is not None:
+ try:
+ from qtconsole.rich_ipython_widget import RichJupyterWidget as \
+ RichIPythonWidget
+ except ImportError:
+ try:
+ from qtconsole.rich_ipython_widget import RichIPythonWidget
+ except ImportError as e:
+ qtconsole = None
+ else:
+ from qtconsole.inprocess import QtInProcessKernelManager
+ else:
+ from qtconsole.inprocess import QtInProcessKernelManager
+
+
+if qtconsole is None:
+ # Import the console machinery from ipython
+
+ # The `has_binding` test of IPython does not find the Qt bindings
+ # in case silx is used in a frozen binary
+ import IPython.external.qt_loaders
+
+ def has_binding(*var, **kw):
+ return True
+
+ IPython.external.qt_loaders.has_binding = has_binding
+
+ from IPython.qt.console.rich_ipython_widget import RichIPythonWidget
+ from IPython.qt.inprocess import QtInProcessKernelManager
+
+
+class IPythonWidget(RichIPythonWidget):
+ """Live IPython console widget.
+
+ :param custom_banner: Custom welcome message to be printed at the top of
+ the console.
+ """
+
+ def __init__(self, parent=None, custom_banner=None, *args, **kwargs):
+ if parent is not None:
+ kwargs["parent"] = parent
+ super(IPythonWidget, self).__init__(*args, **kwargs)
+ if custom_banner is not None:
+ self.banner = custom_banner
+ self.setWindowTitle(self.banner)
+ self.kernel_manager = kernel_manager = QtInProcessKernelManager()
+ kernel_manager.start_kernel()
+ self.kernel_client = kernel_client = self._kernel_manager.client()
+ kernel_client.start_channels()
+
+ def stop():
+ kernel_client.stop_channels()
+ kernel_manager.shutdown_kernel()
+ self.exit_requested.connect(stop)
+
+ def sizeHint(self):
+ """Return a reasonable default size for usage in :class:`PlotWindow`"""
+ return qt.QSize(500, 300)
+
+ def pushVariables(self, variable_dict):
+ """ Given a dictionary containing name / value pairs, push those
+ variables to the IPython console widget.
+
+ :param variable_dict: Dictionary of variables to be pushed to the
+ console's interactive namespace (```{variable_name: object, …}```)
+ """
+ self.kernel_manager.kernel.shell.push(variable_dict)
+
+
+class IPythonDockWidget(qt.QDockWidget):
+ """Dock Widget including a :class:`IPythonWidget` inside
+ a vertical layout.
+
+ :param available_vars: Dictionary of variables to be pushed to the
+ console's interactive namespace: ``{"variable_name": object, …}``
+ :param custom_banner: Custom welcome message to be printed at the top of
+ the console
+ :param title: Dock widget title
+ :param parent: Parent :class:`qt.QMainWindow` containing this
+ :class:`qt.QDockWidget`
+ """
+ def __init__(self, parent=None, available_vars=None, custom_banner=None,
+ title="Console"):
+ super(IPythonDockWidget, self).__init__(title, parent)
+
+ self.ipyconsole = IPythonWidget(custom_banner=custom_banner)
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.setWidget(self.ipyconsole)
+
+ if available_vars is not None:
+ self.ipyconsole.pushVariables(available_vars)
+
+ def showEvent(self, event):
+ """Make sure this widget is raised when it is shown
+ (when it is first created as a tab in PlotWindow or when it is shown
+ again after hiding).
+ """
+ self.raise_()
+
+
+def main():
+ """Run a Qt app with an IPython console"""
+ app = qt.QApplication([])
+ widget = IPythonDockWidget()
+ widget.show()
+ app.exec_()
+
+if __name__ == '__main__':
+ main()
diff --git a/silx/gui/data/ArrayTableModel.py b/silx/gui/data/ArrayTableModel.py
new file mode 100644
index 0000000..87a2fc1
--- /dev/null
+++ b/silx/gui/data/ArrayTableModel.py
@@ -0,0 +1,610 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module defines a data model for displaying and editing arrays of any
+number of dimensions in a table view.
+"""
+from __future__ import division
+import numpy
+import logging
+from silx.gui import qt
+from silx.gui.data.TextFormatter import TextFormatter
+
+__authors__ = ["V.A. Sole"]
+__license__ = "MIT"
+__date__ = "24/01/2017"
+
+
+_logger = logging.getLogger(__name__)
+
+
+def _is_array(data):
+ """Return True if object implements all necessary attributes to be used
+ as a numpy array.
+
+ :param object data: Array-like object (numpy array, h5py dataset...)
+ :return: boolean
+ """
+ # add more required attribute if necessary
+ for attr in ("shape", "dtype"):
+ if not hasattr(data, attr):
+ return False
+ return True
+
+
+class ArrayTableModel(qt.QAbstractTableModel):
+ """This data model provides access to 2D slices in a N-dimensional
+ array.
+
+ A slice for a 3-D array is characterized by a perspective (the number of
+ the axis orthogonal to the slice) and an index at which the slice
+ intersects the orthogonal axis.
+
+ In the n-D case, only slices parallel to the last two axes are handled. A
+ slice is therefore characterized by a list of indices locating the
+ slice on all the :math:`n - 2` orthogonal axes.
+
+ :param parent: Parent QObject
+ :param data: Numpy array, or object implementing a similar interface
+ (e.g. h5py dataset)
+ :param str fmt: Format string for representing numerical values.
+ Default is ``"%g"``.
+ :param sequence[int] perspective: See documentation
+ of :meth:`setPerspective`.
+ """
+ def __init__(self, parent=None, data=None, perspective=None):
+ qt.QAbstractTableModel.__init__(self, parent)
+
+ self._array = None
+ """n-dimensional numpy array"""
+
+ self._bgcolors = None
+ """(n+1)-dimensional numpy array containing RGB(A) color data
+ for the background color
+ """
+
+ self._fgcolors = None
+ """(n+1)-dimensional numpy array containing RGB(A) color data
+ for the foreground color
+ """
+
+ self._formatter = None
+ """Formatter for text representation of data"""
+
+ formatter = TextFormatter(self)
+ formatter.setUseQuoteForText(False)
+ self.setFormatter(formatter)
+
+ self._index = None
+ """This attribute stores the slice index, as a list of indices
+ where the frame intersects orthogonal axis."""
+
+ self._perspective = None
+ """Sequence of dimensions orthogonal to the frame to be viewed.
+ For an array with ``n`` dimensions, this is a sequence of ``n-2``
+ integers. the first dimension is numbered ``0``.
+ By default, the data frames use the last two dimensions as their axes
+ and therefore the perspective is a sequence of the first ``n-2``
+ dimensions.
+ For example, for a 5-D array, the default perspective is ``(0, 1, 2)``
+ and the default frames axes are ``(3, 4)``."""
+
+ # set _data and _perspective
+ self.setArrayData(data, perspective=perspective)
+
+ def _getRowDim(self):
+ """The row axis is the first axis parallel to the frames
+ (lowest dimension number)
+
+ Return None for 0-D (scalar) or 1-D arrays
+ """
+ n_dimensions = len(self._array.shape)
+ if n_dimensions < 2:
+ # scalar or 1D array: no row index
+ return None
+ # take all dimensions and remove the orthogonal ones
+ frame_axes = set(range(0, n_dimensions)) - set(self._perspective)
+ # sanity check
+ assert len(frame_axes) == 2
+ return min(frame_axes)
+
+ def _getColumnDim(self):
+ """The column axis is the second (highest dimension) axis parallel
+ to the frames
+
+ Return None for 0-D (scalar)
+ """
+ n_dimensions = len(self._array.shape)
+ if n_dimensions < 1:
+ # scalar: no column index
+ return None
+ frame_axes = set(range(0, n_dimensions)) - set(self._perspective)
+ # sanity check
+ assert (len(frame_axes) == 2) if n_dimensions > 1 else (len(frame_axes) == 1)
+ return max(frame_axes)
+
+ def _getIndexTuple(self, table_row, table_col):
+ """Return the n-dimensional index of a value in the original array,
+ based on its row and column indices in the table view
+
+ :param table_row: Row index (0-based) of a table cell
+ :param table_col: Column index (0-based) of a table cell
+ :return: Tuple of indices of the element in the numpy array
+ """
+ row_dim = self._getRowDim()
+ col_dim = self._getColumnDim()
+
+ # get indices on all orthogonal axes
+ selection = list(self._index)
+ # insert indices on parallel axes
+ if row_dim is not None:
+ selection.insert(row_dim, table_row)
+ if col_dim is not None:
+ selection.insert(col_dim, table_col)
+ return tuple(selection)
+
+ # Methods to be implemented to subclass QAbstractTableModel
+ def rowCount(self, parent_idx=None):
+ """QAbstractTableModel method
+ Return number of rows to be displayed in table"""
+ row_dim = self._getRowDim()
+ if row_dim is None:
+ # 0-D and 1-D arrays
+ return 1
+ return self._array.shape[row_dim]
+
+ def columnCount(self, parent_idx=None):
+ """QAbstractTableModel method
+ Return number of columns to be displayed in table"""
+ col_dim = self._getColumnDim()
+ if col_dim is None:
+ # 0-D array
+ return 1
+ return self._array.shape[col_dim]
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method to access data values
+ in the format ready to be displayed"""
+ if index.isValid():
+ selection = self._getIndexTuple(index.row(),
+ index.column())
+ if role == qt.Qt.DisplayRole:
+ return self._formatter.toString(self._array[selection])
+
+ if role == qt.Qt.BackgroundRole and self._bgcolors is not None:
+ r, g, b = self._bgcolors[selection][0:3]
+ if self._bgcolors.shape[-1] == 3:
+ return qt.QColor(r, g, b)
+ if self._bgcolors.shape[-1] == 4:
+ a = self._bgcolors[selection][3]
+ return qt.QColor(r, g, b, a)
+
+ if role == qt.Qt.ForegroundRole:
+ if self._fgcolors is not None:
+ r, g, b = self._fgcolors[selection][0:3]
+ if self._fgcolors.shape[-1] == 3:
+ return qt.QColor(r, g, b)
+ if self._fgcolors.shape[-1] == 4:
+ a = self._fgcolors[selection][3]
+ return qt.QColor(r, g, b, a)
+
+ # no fg color given, use black or white
+ # based on luminosity threshold
+ elif self._bgcolors is not None:
+ r, g, b = self._bgcolors[selection][0:3]
+ lum = 0.21 * r + 0.72 * g + 0.07 * b
+ if lum < 128:
+ return qt.QColor(qt.Qt.white)
+ else:
+ return qt.QColor(qt.Qt.black)
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method
+ Return the 0-based row or column index, for display in the
+ horizontal and vertical headers"""
+ if role == qt.Qt.DisplayRole:
+ if orientation == qt.Qt.Vertical:
+ return "%d" % section
+ if orientation == qt.Qt.Horizontal:
+ return "%d" % section
+ return None
+
+ def flags(self, index):
+ """QAbstractTableModel method to inform the view whether data
+ is editable or not."""
+ if not self._editable:
+ return qt.QAbstractTableModel.flags(self, index)
+ return qt.QAbstractTableModel.flags(self, index) | qt.Qt.ItemIsEditable
+
+ def setData(self, index, value, role=None):
+ """QAbstractTableModel method to handle editing data.
+ Cast the new value into the same format as the array before editing
+ the array value."""
+ if index.isValid() and role == qt.Qt.EditRole:
+ try:
+ # cast value to same type as array
+ v = numpy.asscalar(
+ numpy.array(value, dtype=self._array.dtype))
+ except ValueError:
+ return False
+
+ selection = self._getIndexTuple(index.row(),
+ index.column())
+ self._array[selection] = v
+ self.dataChanged.emit(index, index)
+ return True
+ else:
+ return False
+
+ # Public methods
+ def setArrayData(self, data, copy=True,
+ perspective=None, editable=False):
+ """Set the data array and the viewing perspective.
+
+ You can set ``copy=False`` if you need more performances, when dealing
+ with a large numpy array. In this case, a simple reference to the data
+ is used to access the data, rather than a copy of the array.
+
+ .. warning::
+
+ Any change to the data model will affect your original data
+ array, when using a reference rather than a copy..
+
+ :param data: n-dimensional numpy array, or any object that can be
+ converted to a numpy array using ``numpy.array(data)`` (e.g.
+ a nested sequence).
+ :param bool copy: If *True* (default), a copy of the array is stored
+ and the original array is not modified if the table is edited.
+ If *False*, then the behavior depends on the data type:
+ if possible (if the original array is a proper numpy array)
+ a reference to the original array is used.
+ :param perspective: See documentation of :meth:`setPerspective`.
+ If None, the default perspective is the list of the first ``n-2``
+ dimensions, to view frames parallel to the last two axes.
+ :param bool editable: Flag to enable editing data. Default *False*.
+ """
+ if qt.qVersion() > "4.6":
+ self.beginResetModel()
+ else:
+ self.reset()
+
+ if data is None:
+ # empty array
+ self._array = numpy.array([])
+ elif copy:
+ # copy requested (default)
+ self._array = numpy.array(data, copy=True)
+ elif not _is_array(data):
+ raise TypeError("data is not a proper array. Try setting" +
+ " copy=True to convert it into a numpy array" +
+ " (this will cause the data to be copied!)")
+ # # copy not requested, but necessary
+ # _logger.warning(
+ # "data is not an array-like object. " +
+ # "Data must be copied.")
+ # self._array = numpy.array(data, copy=True)
+ else:
+ # Copy explicitly disabled & data implements required attributes.
+ # We can use a reference.
+ self._array = data
+
+ # reset colors to None if new data shape is inconsistent
+ valid_color_shapes = (self._array.shape + (3,),
+ self._array.shape + (4,))
+ if self._bgcolors is not None:
+ if self._bgcolors.shape not in valid_color_shapes:
+ self._bgcolors = None
+ if self._fgcolors is not None:
+ if self._fgcolors.shape not in valid_color_shapes:
+ self._fgcolors = None
+
+ self.setEditable(editable)
+
+ self._index = [0 for _i in range((len(self._array.shape) - 2))]
+ self._perspective = tuple(perspective) if perspective is not None else\
+ tuple(range(0, len(self._array.shape) - 2))
+
+ if qt.qVersion() > "4.6":
+ self.endResetModel()
+
+ def setArrayColors(self, bgcolors=None, fgcolors=None):
+ """Set the colors for all table cells by passing an array
+ of RGB or RGBA values (integers between 0 and 255).
+
+ The shape of the colors array must be consistent with the data shape.
+
+ If the data array is n-dimensional, the colors array must be
+ (n+1)-dimensional, with the first n-dimensions identical to the data
+ array dimensions, and the last dimension length-3 (RGB) or
+ length-4 (RGBA).
+
+ :param bgcolors: RGB or RGBA colors array, defining the background color
+ for each cell in the table.
+ :param fgcolors: RGB or RGBA colors array, defining the foreground color
+ (text color) for each cell in the table.
+ """
+ # array must be RGB or RGBA
+ valid_shapes = (self._array.shape + (3,), self._array.shape + (4,))
+ errmsg = "Inconsistent shape for color array, should be %s or %s" % valid_shapes
+
+ if bgcolors is not None:
+ if not _is_array(bgcolors):
+ bgcolors = numpy.array(bgcolors)
+ assert bgcolors.shape in valid_shapes, errmsg
+
+ self._bgcolors = bgcolors
+
+ if fgcolors is not None:
+ if not _is_array(fgcolors):
+ fgcolors = numpy.array(fgcolors)
+ assert fgcolors.shape in valid_shapes, errmsg
+
+ self._fgcolors = fgcolors
+
+ def setEditable(self, editable):
+ """Set flags to make the data editable.
+
+ .. warning::
+
+ If the data is a reference to a h5py dataset open in read-only
+ mode, setting *editable=True* will fail and print a warning.
+
+ .. warning::
+
+ Making the data editable means that the underlying data structure
+ in this data model will be modified.
+ If the data is a reference to a public object (open with
+ ``copy=False``), this could have side effects. If it is a
+ reference to an HDF5 dataset, this means the file will be
+ modified.
+
+ :param bool editable: Flag to enable editing data.
+ :return: True if setting desired flag succeeded, False if it failed.
+ """
+ self._editable = editable
+ if hasattr(self._array, "file"):
+ if hasattr(self._array.file, "mode"):
+ if editable and self._array.file.mode == "r":
+ _logger.warning(
+ "Data is a HDF5 dataset open in read-only " +
+ "mode. Editing must be disabled.")
+ self._editable = False
+ return False
+ return True
+
+ def getData(self, copy=True):
+ """Return a copy of the data array, or a reference to it
+ if *copy=False* is passed as parameter.
+
+ In case the shape was modified, to convert 0-D or 1-D data
+ into 2-D data, the original shape is restored in the returned data.
+
+ :param bool copy: If *True* (default), return a copy of the data. If
+ *False*, return a reference.
+ :return: numpy array of data, or reference to original data object
+ if *copy=False*
+ """
+ data = self._array if not copy else numpy.array(self._array, copy=True)
+ return data
+
+ def setFrameIndex(self, index):
+ """Set the active slice index.
+
+ This method is only relevant to arrays with at least 3 dimensions.
+
+ :param index: Index of the active slice in the array.
+ In the general n-D case, this is a sequence of :math:`n - 2`
+ indices where the slice intersects the respective orthogonal axes.
+ :raise IndexError: If any index in the index sequence is out of bound
+ on its respective axis.
+ """
+ shape = self._array.shape
+ if len(shape) < 3:
+ # index is ignored
+ return
+
+ if qt.qVersion() > "4.6":
+ self.beginResetModel()
+ else:
+ self.reset()
+
+ if len(shape) == 3:
+ len_ = shape[self._perspective[0]]
+ # accept integers as index in the case of 3-D arrays
+ if not hasattr(index, "__len__"):
+ self._index = [index]
+ else:
+ self._index = index
+ if not 0 <= self._index[0] < len_:
+ raise ValueError("Index must be a positive integer " +
+ "lower than %d" % len_)
+ else:
+ # general n-D case
+ for i_, idx in enumerate(index):
+ if not 0 <= idx < shape[self._perspective[i_]]:
+ raise IndexError("Invalid index %d " % idx +
+ "not in range 0-%d" % (shape[i_] - 1))
+ self._index = index
+
+ if qt.qVersion() > "4.6":
+ self.endResetModel()
+
+ def setFormatter(self, formatter):
+ """Set the formatter object to be used to display data from the model
+
+ :param TextFormatter formatter: Formatter to use
+ """
+ if formatter is self._formatter:
+ return
+
+ if qt.qVersion() > "4.6":
+ self.beginResetModel()
+
+ if self._formatter is not None:
+ self._formatter.formatChanged.disconnect(self.__formatChanged)
+
+ self._formatter = formatter
+ if self._formatter is not None:
+ self._formatter.formatChanged.connect(self.__formatChanged)
+
+ if qt.qVersion() > "4.6":
+ self.endResetModel()
+ else:
+ self.reset()
+
+ def getFormatter(self):
+ """Returns the text formatter used.
+
+ :rtype: TextFormatter
+ """
+ return self._formatter
+
+ def __formatChanged(self):
+ """Called when the format changed.
+ """
+ self.reset()
+
+ def setPerspective(self, perspective):
+ """Set the perspective by defining a sequence listing all axes
+ orthogonal to the frame or 2-D slice to be visualized.
+
+ Alternatively, you can use :meth:`setFrameAxes` for the complementary
+ approach of specifying the two axes parallel to the frame.
+
+ In the 1-D or 2-D case, this parameter is irrelevant.
+
+ In the 3-D case, if the unit vectors describing
+ your axes are :math:`\vec{x}, \vec{y}, \vec{z}`, a perspective of 0
+ means you slices are parallel to :math:`\vec{y}\vec{z}`, 1 means they
+ are parallel to :math:`\vec{x}\vec{z}` and 2 means they
+ are parallel to :math:`\vec{x}\vec{y}`.
+
+ In the n-D case, this parameter is a sequence of :math:`n-2` axes
+ numbers.
+ For instance if you want to display 2-D frames whose axes are the
+ second and third dimensions of a 5-D array, set the perspective to
+ ``(0, 3, 4)``.
+
+ :param perspective: Sequence of dimensions/axes orthogonal to the
+ frames.
+ :raise: IndexError if any value in perspective is higher than the
+ number of dimensions minus one (first dimension is 0), or
+ if the number of values is different from the number of dimensions
+ minus two.
+ """
+ n_dimensions = len(self._array.shape)
+ if n_dimensions < 3:
+ _logger.warning(
+ "perspective is not relevant for 1D and 2D arrays")
+ return
+
+ if not hasattr(perspective, "__len__"):
+ # we can tolerate an integer for 3-D array
+ if n_dimensions == 3:
+ perspective = [perspective]
+ else:
+ raise ValueError("perspective must be a sequence of integers")
+
+ # ensure unicity of dimensions in perspective
+ perspective = tuple(set(perspective))
+
+ if len(perspective) != n_dimensions - 2 or\
+ min(perspective) < 0 or max(perspective) >= n_dimensions:
+ raise IndexError(
+ "Invalid perspective " + str(perspective) +
+ " for %d-D array " % n_dimensions +
+ "with shape " + str(self._array.shape))
+
+ if qt.qVersion() > "4.6":
+ self.beginResetModel()
+ else:
+ self.reset()
+
+ self._perspective = perspective
+
+ # reset index
+ self._index = [0 for _i in range(n_dimensions - 2)]
+
+ if qt.qVersion() > "4.6":
+ self.endResetModel()
+
+ def setFrameAxes(self, row_axis, col_axis):
+ """Set the perspective by specifying the two axes parallel to the frame
+ to be visualised.
+
+ The complementary approach of defining the orthogonal axes can be used
+ with :meth:`setPerspective`.
+
+ :param int row_axis: Index (0-based) of the first dimension used as a frame
+ axis
+ :param int col_axis: Index (0-based) of the 2nd dimension used as a frame
+ axis
+ :raise: IndexError if axes are invalid
+ """
+ if row_axis > col_axis:
+ _logger.warning("The dimension of the row axis must be lower " +
+ "than the dimension of the column axis. Swapping.")
+ row_axis, col_axis = min(row_axis, col_axis), max(row_axis, col_axis)
+
+ n_dimensions = len(self._array.shape)
+ if n_dimensions < 3:
+ _logger.warning(
+ "Frame axes cannot be changed for 1D and 2D arrays")
+ return
+
+ perspective = tuple(set(range(0, n_dimensions)) - {row_axis, col_axis})
+
+ if len(perspective) != n_dimensions - 2 or\
+ min(perspective) < 0 or max(perspective) >= n_dimensions:
+ raise IndexError(
+ "Invalid perspective " + str(perspective) +
+ " for %d-D array " % n_dimensions +
+ "with shape " + str(self._array.shape))
+
+ if qt.qVersion() > "4.6":
+ self.beginResetModel()
+ else:
+ self.reset()
+
+ self._perspective = perspective
+ # reset index
+ self._index = [0 for _i in range(n_dimensions - 2)]
+
+ if qt.qVersion() > "4.6":
+ self.endResetModel()
+
+
+if __name__ == "__main__":
+ app = qt.QApplication([])
+ w = qt.QTableView()
+ d = numpy.random.normal(0, 1, (5, 1000, 1000))
+ for i in range(5):
+ d[i, :, :] += i * 10
+ m = ArrayTableModel(data=d)
+ w.setModel(m)
+ m.setFrameIndex(3)
+ # m.setArrayData(numpy.ones((100,)))
+ w.show()
+ app.exec_()
diff --git a/silx/gui/data/ArrayTableWidget.py b/silx/gui/data/ArrayTableWidget.py
new file mode 100644
index 0000000..ba3fa11
--- /dev/null
+++ b/silx/gui/data/ArrayTableWidget.py
@@ -0,0 +1,490 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines a widget designed to display data arrays with any
+number of dimensions as 2D frames (images, slices) in a table view.
+The dimensions not displayed in the table can be browsed using improved
+sliders.
+
+The widget uses a TableView that relies on a custom abstract item
+model: :class:`silx.gui.data.ArrayTableModel`.
+"""
+from __future__ import division
+import sys
+
+from silx.gui import qt
+from silx.gui.widgets.TableWidget import TableView
+from .ArrayTableModel import ArrayTableModel
+from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/01/2017"
+
+
+class AxesSelector(qt.QWidget):
+ """Widget with two combo-boxes to select two dimensions among
+ all possible dimensions of an n-dimensional array.
+
+ The first combobox contains values from :math:`0` to :math:`n-2`.
+
+ The choices in the 2nd CB depend on the value selected in the first one.
+ If the value selected in the first CB is :math:`m`, the second one lets you
+ select values from :math:`m+1` to :math:`n-1`.
+
+ The two axes can be used to select the row axis and the column axis t
+ display a slice of the array data in a table view.
+ """
+ sigDimensionsChanged = qt.Signal(int, int)
+ """Signal emitted whenever one of the comboboxes is changed.
+ The signal carries the two selected dimensions."""
+
+ def __init__(self, parent=None, n=None):
+ qt.QWidget.__init__(self, parent)
+ self.layout = qt.QHBoxLayout(self)
+ self.layout.setContentsMargins(0, 2, 0, 2)
+ self.layout.setSpacing(10)
+
+ self.rowsCB = qt.QComboBox(self)
+ self.columnsCB = qt.QComboBox(self)
+
+ self.layout.addWidget(qt.QLabel("Rows dimension", self))
+ self.layout.addWidget(self.rowsCB)
+ self.layout.addWidget(qt.QLabel(" ", self))
+ self.layout.addWidget(qt.QLabel("Columns dimension", self))
+ self.layout.addWidget(self.columnsCB)
+ self.layout.addStretch(1)
+
+ self._slotsAreConnected = False
+ if n is not None:
+ self.setNDimensions(n)
+
+ def setNDimensions(self, n):
+ """Initialize combo-boxes depending on number of dimensions of array.
+ Initially, the rows dimension is the second-to-last one, and the
+ columns dimension is the last one.
+
+ Link the CBs together. MAke them emit a signal when their value is
+ changed.
+
+ :param int n: Number of dimensions of array
+ """
+ # remember the number of dimensions and the rows dimension
+ self.n = n
+ self._rowsDim = n - 2
+
+ # ensure slots are disconnected before (re)initializing widget
+ if self._slotsAreConnected:
+ self.rowsCB.currentIndexChanged.disconnect(self._rowDimChanged)
+ self.columnsCB.currentIndexChanged.disconnect(self._colDimChanged)
+
+ self._clear()
+ self.rowsCB.addItems([str(i) for i in range(n - 1)])
+ self.rowsCB.setCurrentIndex(n - 2)
+ if n >= 1:
+ self.columnsCB.addItem(str(n - 1))
+ self.columnsCB.setCurrentIndex(0)
+
+ # reconnect slots
+ self.rowsCB.currentIndexChanged.connect(self._rowDimChanged)
+ self.columnsCB.currentIndexChanged.connect(self._colDimChanged)
+ self._slotsAreConnected = True
+
+ # emit new dimensions
+ if n > 2:
+ self.sigDimensionsChanged.emit(n - 2, n - 1)
+
+ def setDimensions(self, row_dim, col_dim):
+ """Set the rows and columns dimensions.
+
+ The rows dimension must be lower than the columns dimension.
+
+ :param int row_dim: Rows dimension
+ :param int col_dim: Columns dimension
+ """
+ if row_dim >= col_dim:
+ raise IndexError("Row dimension must be lower than column dimension")
+ if not (0 <= row_dim < self.n - 1):
+ raise IndexError("Row dimension must be between 0 and %d" % (self.n - 2))
+ if not (row_dim < col_dim <= self.n - 1):
+ raise IndexError("Col dimension must be between %d and %d" % (row_dim + 1, self.n - 1))
+
+ # set the rows dimension; this triggers an update of columnsCB
+ self.rowsCB.setCurrentIndex(row_dim)
+ # columnsCB first item is "row_dim + 1". So index of "col_dim" is
+ # col_dim - (row_dim + 1)
+ self.columnsCB.setCurrentIndex(col_dim - row_dim - 1)
+
+ def getDimensions(self):
+ """Return a 2-tuple of the rows dimension and the columns dimension.
+
+ :return: 2-tuple of axes numbers (row_dimension, col_dimension)
+ """
+ return self._getRowDim(), self._getColDim()
+
+ def _clear(self):
+ """Empty the combo-boxes"""
+ self.rowsCB.clear()
+ self.columnsCB.clear()
+
+ def _getRowDim(self):
+ """Get rows dimension, selected in :attr:`rowsCB`
+ """
+ # rows combobox contains elements "0", ..."n-2",
+ # so the selected dim is always equal to the index
+ return self.rowsCB.currentIndex()
+
+ def _getColDim(self):
+ """Get columns dimension, selected in :attr:`columnsCB`"""
+ # columns combobox contains elements "row_dim+1", "row_dim+2", ..., "n-1"
+ # so the selected dim is equal to row_dim + 1 + index
+ return self._rowsDim + 1 + self.columnsCB.currentIndex()
+
+ def _rowDimChanged(self):
+ """Update columns combobox when the rows dimension is changed.
+
+ Emit :attr:`sigDimensionsChanged`"""
+ old_col_dim = self._getColDim()
+ new_row_dim = self._getRowDim()
+
+ # clear cols CB
+ self.columnsCB.currentIndexChanged.disconnect(self._colDimChanged)
+ self.columnsCB.clear()
+ # refill cols CB
+ for i in range(new_row_dim + 1, self.n):
+ self.columnsCB.addItem(str(i))
+
+ # keep previous col dimension if possible
+ new_col_cb_idx = old_col_dim - (new_row_dim + 1)
+ if new_col_cb_idx < 0:
+ # if row_dim is now greater than the previous col_dim,
+ # we select a new col_dim = row_dim + 1 (first element in cols CB)
+ new_col_cb_idx = 0
+ self.columnsCB.setCurrentIndex(new_col_cb_idx)
+
+ # reconnect slot
+ self.columnsCB.currentIndexChanged.connect(self._colDimChanged)
+
+ self._rowsDim = new_row_dim
+
+ self.sigDimensionsChanged.emit(self._getRowDim(), self._getColDim())
+
+ def _colDimChanged(self):
+ """Emit :attr:`sigDimensionsChanged`"""
+ self.sigDimensionsChanged.emit(self._getRowDim(), self._getColDim())
+
+
+def _get_shape(array_like):
+ """Return shape of an array like object.
+
+ In case the object is a nested sequence (list of lists, tuples...),
+ the size of each dimension is assumed to be uniform, and is deduced from
+ the length of the first sequence.
+
+ :param array_like: Array like object: numpy array, hdf5 dataset,
+ multi-dimensional sequence
+ :return: Shape of array, as a tuple of integers
+ """
+ if hasattr(array_like, "shape"):
+ return array_like.shape
+
+ shape = []
+ subsequence = array_like
+ while hasattr(subsequence, "__len__"):
+ shape.append(len(subsequence))
+ subsequence = subsequence[0]
+
+ return tuple(shape)
+
+
+class ArrayTableWidget(qt.QWidget):
+ """This widget is designed to display data of 2D frames (images, slices)
+ in a table view. The widget can load any n-dimensional array, and display
+ any 2-D frame/slice in the array.
+
+ The index of the dimensions orthogonal to the displayed frame can be set
+ interactively using a browser widget (sliders, buttons and text entries).
+
+ To set the data, use :meth:`setArrayData`.
+ To select the perspective, use :meth:`setPerspective` or
+ use :meth:`setFrameAxes`.
+ To select the frame, use :meth:`setFrameIndex`.
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: parent QWidget
+ :param labels: list of labels for each dimension of the array
+ """
+ qt.QWidget.__init__(self, parent)
+ self.mainLayout = qt.QVBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(0)
+
+ self.browserContainer = qt.QWidget(self)
+ self.browserLayout = qt.QGridLayout(self.browserContainer)
+ self.browserLayout.setContentsMargins(0, 0, 0, 0)
+ self.browserLayout.setSpacing(0)
+
+ self._dimensionLabelsText = []
+ """List of text labels sorted in the increasing order of the dimension
+ they apply to."""
+ self._browserLabels = []
+ """List of QLabel widgets."""
+ self._browserWidgets = []
+ """List of HorizontalSliderWithBrowser widgets."""
+
+ self.axesSelector = AxesSelector(self)
+
+ self.view = TableView(self)
+
+ self.mainLayout.addWidget(self.browserContainer)
+ self.mainLayout.addWidget(self.axesSelector)
+ self.mainLayout.addWidget(self.view)
+
+ self.model = ArrayTableModel(self)
+ self.view.setModel(self.model)
+
+ def setArrayData(self, data, labels=None, copy=True, editable=False):
+ """Set the data array. Update frame browsers and labels.
+
+ :param data: Numpy array or similar object (e.g. nested sequence,
+ h5py dataset...)
+ :param labels: list of labels for each dimension of the array, or
+ boolean ``True`` to use default labels ("dimension 0",
+ "dimension 1", ...). `None` to disable labels (default).
+ :param bool copy: If *True*, store a copy of *data* in the model. If
+ *False*, store a reference to *data* if possible (only possible if
+ *data* is a proper numpy array or an object that implements the
+ same methods).
+ :param bool editable: Flag to enable editing data. Default is *False*
+ """
+ self._data_shape = _get_shape(data)
+
+ n_widgets = len(self._browserWidgets)
+ n_dimensions = len(self._data_shape)
+
+ # Reset text of labels
+ self._dimensionLabelsText = []
+ for i in range(n_dimensions):
+ if labels in [True, 1]:
+ label_text = "Dimension %d" % i
+ elif labels is None or i >= len(labels):
+ label_text = ""
+ else:
+ label_text = labels[i]
+ self._dimensionLabelsText.append(label_text)
+
+ # not enough widgets, create new ones (we need n_dim - 2)
+ for i in range(n_widgets, n_dimensions - 2):
+ browser = HorizontalSliderWithBrowser(self.browserContainer)
+ self.browserLayout.addWidget(browser, i, 1)
+ self._browserWidgets.append(browser)
+ browser.valueChanged.connect(self._browserSlot)
+ browser.setEnabled(False)
+ browser.hide()
+
+ label = qt.QLabel(self.browserContainer)
+ self._browserLabels.append(label)
+ self.browserLayout.addWidget(label, i, 0)
+ label.hide()
+
+ n_widgets = len(self._browserWidgets)
+ for i in range(n_widgets):
+ label = self._browserLabels[i]
+ browser = self._browserWidgets[i]
+
+ if (i + 2) < n_dimensions:
+ label.setText(self._dimensionLabelsText[i])
+ browser.setRange(0, self._data_shape[i] - 1)
+ browser.setEnabled(True)
+ browser.show()
+ if labels is not None:
+ label.show()
+ else:
+ label.hide()
+ else:
+ browser.setEnabled(False)
+ browser.hide()
+ label.hide()
+
+ # set model
+ self.model.setArrayData(data, copy=copy, editable=editable)
+ # some linux distributions need this call
+ self.view.setModel(self.model)
+ if editable:
+ self.view.enableCut()
+ self.view.enablePaste()
+
+ # initialize & connect axesSelector
+ self.axesSelector.setNDimensions(n_dimensions)
+ self.axesSelector.sigDimensionsChanged.connect(self.setFrameAxes)
+
+ def setArrayColors(self, bgcolors=None, fgcolors=None):
+ """Set the colors for all table cells by passing an array
+ of RGB or RGBA values (integers between 0 and 255).
+
+ The shape of the colors array must be consistent with the data shape.
+
+ If the data array is n-dimensional, the colors array must be
+ (n+1)-dimensional, with the first n-dimensions identical to the data
+ array dimensions, and the last dimension length-3 (RGB) or
+ length-4 (RGBA).
+
+ :param bgcolors: RGB or RGBA colors array, defining the background color
+ for each cell in the table.
+ :param fgcolors: RGB or RGBA colors array, defining the foreground color
+ (text color) for each cell in the table.
+ """
+ self.model.setArrayColors(bgcolors, fgcolors)
+
+ def displayAxesSelector(self, isVisible):
+ """Allow to display or hide the axes selector.
+
+ :param bool isVisible: True to display the axes selector.
+ """
+ self.axesSelector.setVisible(isVisible)
+
+ def setFrameIndex(self, index):
+ """Set the active slice/image index in the n-dimensional array.
+
+ A frame is a 2D array extracted from an array. This frame is
+ necessarily parallel to 2 axes, and orthogonal to all other axes.
+
+ The index of a frame is a sequence of indices along the orthogonal
+ axes, where the frame intersects the respective axis. The indices
+ are listed in the same order as the corresponding dimensions of the
+ data array.
+
+ For example, it the data array has 5 dimensions, and we are
+ considering frames whose parallel axes are the 2nd and 4th dimensions
+ of the array, the frame index will be a sequence of length 3
+ corresponding to the indices where the frame intersects the 1st, 3rd
+ and 5th axes.
+
+ :param index: Sequence of indices defining the active data slice in
+ a n-dimensional array. The sequence length is :math:`n-2`
+ :raise: IndexError if any index in the index sequence is out of bound
+ on its respective axis.
+ """
+ self.model.setFrameIndex(index)
+
+ def _resetBrowsers(self, perspective):
+ """Adjust limits for browsers based on the perspective and the
+ size of the corresponding dimensions. Reset the index to 0.
+ Update the dimension in the labels.
+
+ :param perspective: Sequence of axes/dimensions numbers (0-based)
+ defining the axes orthogonal to the frame.
+ """
+ # for 3D arrays we can accept an int rather than a 1-tuple
+ if not hasattr(perspective, "__len__"):
+ perspective = [perspective]
+
+ # perspective must be sorted
+ perspective = sorted(perspective)
+
+ n_dimensions = len(self._data_shape)
+ for i in range(n_dimensions - 2):
+ browser = self._browserWidgets[i]
+ label = self._browserLabels[i]
+ browser.setRange(0, self._data_shape[perspective[i]] - 1)
+ browser.setValue(0)
+ label.setText(self._dimensionLabelsText[perspective[i]])
+
+ def setPerspective(self, perspective):
+ """Set the *perspective* by specifying which axes are orthogonal
+ to the frame.
+
+ For the opposite approach (defining parallel axes), use
+ :meth:`setFrameAxes` instead.
+
+ :param perspective: Sequence of unique axes numbers (0-based) defining
+ the orthogonal axes. For a n-dimensional array, the sequence
+ length is :math:`n-2`. The order is of the sequence is not taken
+ into account (the dimensions are displayed in increasing order
+ in the widget).
+ """
+ self.model.setPerspective(perspective)
+ self._resetBrowsers(perspective)
+
+ def setFrameAxes(self, row_axis, col_axis):
+ """Set the *perspective* by specifying which axes are parallel
+ to the frame.
+
+ For the opposite approach (defining orthogonal axes), use
+ :meth:`setPerspective` instead.
+
+ :param int row_axis: Index (0-based) of the first dimension used as a frame
+ axis
+ :param int col_axis: Index (0-based) of the 2nd dimension used as a frame
+ axis
+ """
+ self.model.setFrameAxes(row_axis, col_axis)
+ n_dimensions = len(self._data_shape)
+ perspective = tuple(set(range(0, n_dimensions)) - {row_axis, col_axis})
+ self._resetBrowsers(perspective)
+
+ def _browserSlot(self, value):
+ index = []
+ for browser in self._browserWidgets:
+ if browser.isEnabled():
+ index.append(browser.value())
+ self.setFrameIndex(index)
+ self.view.reset()
+
+ def getData(self, copy=True):
+ """Return a copy of the data array, or a reference to it if
+ *copy=False* is passed as parameter.
+
+ :param bool copy: If *True* (default), return a copy of the data. If
+ *False*, return a reference.
+ :return: Numpy array of data, or reference to original data object
+ if *copy=False*
+ """
+ return self.model.getData(copy=copy)
+
+
+def main():
+ import numpy
+ a = qt.QApplication([])
+ d = numpy.random.normal(0, 1, (4, 5, 1000, 1000))
+ for j in range(4):
+ for i in range(5):
+ d[j, i, :, :] += i + 10 * j
+ w = ArrayTableWidget()
+ if "2" in sys.argv:
+ print("sending a single image")
+ w.setArrayData(d[0, 0])
+ elif "3" in sys.argv:
+ print("sending 5 images")
+ w.setArrayData(d[0])
+ else:
+ print("sending 4 * 5 images ")
+ w.setArrayData(d, labels=True)
+ w.show()
+ a.exec_()
+
+if __name__ == "__main__":
+ main()
diff --git a/silx/gui/data/DataViewer.py b/silx/gui/data/DataViewer.py
new file mode 100644
index 0000000..3a3ac64
--- /dev/null
+++ b/silx/gui/data/DataViewer.py
@@ -0,0 +1,464 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines a widget designed to display data using to most adapted
+view from available ones from silx.
+"""
+from __future__ import division
+
+from silx.gui.data import DataViews
+from silx.gui.data.DataViews import _normalizeData
+import logging
+from silx.gui import qt
+from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "26/04/2017"
+
+
+_logger = logging.getLogger(__name__)
+
+
+class DataViewer(qt.QFrame):
+ """Widget to display any kind of data
+
+ .. image:: img/DataViewer.png
+
+ The method :meth:`setData` allows to set any data to the widget. Mostly
+ `numpy.array` and `h5py.Dataset` are supported with adapted views. Other
+ data types are displayed using a text viewer.
+
+ A default view is automatically selected when a data is set. The method
+ :meth:`setDisplayMode` allows to change the view. To have a graphical tool
+ to select the view, prefer using the widget :class:`DataViewerFrame`.
+
+ The dimension of the input data and the expected dimension of the selected
+ view can differ. For example you can display an image (2D) from 4D
+ data. In this case a :class:`NumpyAxesSelector` is displayed to allow the
+ user to select the axis mapping and the slicing of other axes.
+
+ .. code-block:: python
+
+ import numpy
+ data = numpy.random.rand(500,500)
+ viewer = DataViewer()
+ viewer.setData(data)
+ viewer.setVisible(True)
+ """
+
+ EMPTY_MODE = 0
+ PLOT1D_MODE = 10
+ PLOT2D_MODE = 20
+ PLOT3D_MODE = 30
+ RAW_MODE = 40
+ RAW_ARRAY_MODE = 41
+ RAW_RECORD_MODE = 42
+ RAW_SCALAR_MODE = 43
+ STACK_MODE = 50
+ HDF5_MODE = 60
+
+ displayedViewChanged = qt.Signal(object)
+ """Emitted when the displayed view changes"""
+
+ dataChanged = qt.Signal()
+ """Emitted when the data changes"""
+
+ currentAvailableViewsChanged = qt.Signal()
+ """Emitted when the current available views (which support the current
+ data) change"""
+
+ def __init__(self, parent=None):
+ """Constructor
+
+ :param QWidget parent: The parent of the widget
+ """
+ super(DataViewer, self).__init__(parent)
+
+ self.__stack = qt.QStackedWidget(self)
+ self.__numpySelection = NumpyAxesSelector(self)
+ self.__numpySelection.selectedAxisChanged.connect(self.__numpyAxisChanged)
+ self.__numpySelection.selectionChanged.connect(self.__numpySelectionChanged)
+ self.__numpySelection.customAxisChanged.connect(self.__numpyCustomAxisChanged)
+
+ self.setLayout(qt.QVBoxLayout(self))
+ self.layout().addWidget(self.__stack, 1)
+
+ group = qt.QGroupBox(self)
+ group.setLayout(qt.QVBoxLayout())
+ group.layout().addWidget(self.__numpySelection)
+ group.setTitle("Axis selection")
+ self.__axisSelection = group
+
+ self.layout().addWidget(self.__axisSelection)
+
+ self.__currentAvailableViews = []
+ self.__currentView = None
+ self.__data = None
+ self.__useAxisSelection = False
+ self.__userSelectedView = None
+
+ self.__views = []
+ self.__index = {}
+ """store stack index for each views"""
+
+ self._initializeViews()
+
+ def _initializeViews(self):
+ """Inisialize the available views"""
+ views = self.createDefaultViews(self.__stack)
+ self.__views = list(views)
+ self.setDisplayMode(self.EMPTY_MODE)
+
+ def createDefaultViews(self, parent=None):
+ """Create and returns available views which can be displayed by default
+ by the data viewer. It is called internally by the widget. It can be
+ overwriten to provide a different set of viewers.
+
+ :param QWidget parent: QWidget parent of the views
+ :rtype: list[silx.gui.data.DataViews.DataView]
+ """
+ viewClasses = [
+ DataViews._EmptyView,
+ DataViews._Hdf5View,
+ DataViews._NXdataView,
+ DataViews._Plot1dView,
+ DataViews._Plot2dView,
+ DataViews._Plot3dView,
+ DataViews._RawView,
+ DataViews._StackView,
+ ]
+ views = []
+ for viewClass in viewClasses:
+ try:
+ view = viewClass(parent)
+ views.append(view)
+ except Exception:
+ _logger.warning("%s instantiation failed. View is ignored" % viewClass.__name__)
+ _logger.debug("Backtrace", exc_info=True)
+
+ return views
+
+ def clear(self):
+ """Clear the widget"""
+ self.setData(None)
+
+ def normalizeData(self, data):
+ """Returns a normalized data if the embed a numpy or a dataset.
+ Else returns the data."""
+ return _normalizeData(data)
+
+ def __getStackIndex(self, view):
+ """Get the stack index containing the view.
+
+ :param silx.gui.data.DataViews.DataView view: The view
+ """
+ if view not in self.__index:
+ widget = view.getWidget()
+ index = self.__stack.addWidget(widget)
+ self.__index[view] = index
+ else:
+ index = self.__index[view]
+ return index
+
+ def __clearCurrentView(self):
+ """Clear the current selected view"""
+ view = self.__currentView
+ if view is not None:
+ view.clear()
+
+ def __numpyCustomAxisChanged(self, name, value):
+ view = self.__currentView
+ if view is not None:
+ view.setCustomAxisValue(name, value)
+
+ def __updateNumpySelectionAxis(self):
+ """
+ Update the numpy-selector according to the needed axis names
+ """
+ previous = self.__numpySelection.blockSignals(True)
+ self.__numpySelection.clear()
+ info = DataViews.DataInfo(self.__data)
+ axisNames = self.__currentView.axesNames(self.__data, info)
+ if info.isArray and self.__data is not None and len(axisNames) > 0:
+ self.__useAxisSelection = True
+ self.__numpySelection.setAxisNames(axisNames)
+ self.__numpySelection.setCustomAxis(self.__currentView.customAxisNames())
+ data = self.normalizeData(self.__data)
+ self.__numpySelection.setData(data)
+ if hasattr(data, "shape"):
+ isVisible = not (len(axisNames) == 1 and len(data.shape) == 1)
+ else:
+ isVisible = True
+ self.__axisSelection.setVisible(isVisible)
+ else:
+ self.__useAxisSelection = False
+ self.__axisSelection.setVisible(False)
+ self.__numpySelection.blockSignals(previous)
+
+ def __updateDataInView(self):
+ """
+ Update the views using the current data
+ """
+ if self.__useAxisSelection:
+ self.__displayedData = self.__numpySelection.selectedData()
+ else:
+ self.__displayedData = self.__data
+
+ qt.QTimer.singleShot(10, self.__setDataInView)
+
+ def __setDataInView(self):
+ self.__currentView.setData(self.__displayedData)
+
+ def setDisplayedView(self, view):
+ """Set the displayed view.
+
+ Change the displayed view according to the view itself.
+
+ :param silx.gui.data.DataViews.DataView view: The DataView to use to display the data
+ """
+ self.__userSelectedView = view
+ self._setDisplayedView(view)
+
+ def _setDisplayedView(self, view):
+ """Internal set of the displayed view.
+
+ Change the displayed view according to the view itself.
+
+ :param silx.gui.data.DataViews.DataView view: The DataView to use to display the data
+ """
+ if self.__currentView is view:
+ return
+ self.__clearCurrentView()
+ self.__currentView = view
+ self.__updateNumpySelectionAxis()
+ self.__updateDataInView()
+ stackIndex = self.__getStackIndex(self.__currentView)
+ if self.__currentView is not None:
+ self.__currentView.select()
+ self.__stack.setCurrentIndex(stackIndex)
+ self.displayedViewChanged.emit(view)
+
+ def getViewFromModeId(self, modeId):
+ """Returns the first available view which have the requested modeId.
+
+ :param int modeId: Requested mode id
+ :rtype: silx.gui.data.DataViews.DataView
+ """
+ for view in self.__views:
+ if view.modeId() == modeId:
+ return view
+ return view
+
+ def setDisplayMode(self, modeId):
+ """Set the displayed view using display mode.
+
+ Change the displayed view according to the requested mode.
+
+ :param int modeId: Display mode, one of
+
+ - `EMPTY_MODE`: display nothing
+ - `PLOT1D_MODE`: display the data as a curve
+ - `PLOT2D_MODE`: display the data as an image
+ - `PLOT3D_MODE`: display the data as an isosurface
+ - `RAW_MODE`: display the data as a table
+ - `STACK_MODE`: display the data as a stack of images
+ - `HDF5_MODE`: display the data as a table
+ """
+ try:
+ view = self.getViewFromModeId(modeId)
+ except KeyError:
+ raise ValueError("Display mode %s is unknown" % modeId)
+ self._setDisplayedView(view)
+
+ def displayedView(self):
+ """Returns the current displayed view.
+
+ :rtype: silx.gui.data.DataViews.DataView
+ """
+ return self.__currentView
+
+ def addView(self, view):
+ """Allow to add a view to the dataview.
+
+ If the current data support this view, it will be displayed.
+
+ :param DataView view: A dataview
+ """
+ self.__views.append(view)
+ # TODO It can be skipped if the view do not support the data
+ self.__updateAvailableViews()
+
+ def removeView(self, view):
+ """Allow to remove a view which was available from the dataview.
+
+ If the view was displayed, the widget will be updated.
+
+ :param DataView view: A dataview
+ """
+ self.__views.remove(view)
+ self.__stack.removeWidget(view.getWidget())
+ # invalidate the full index. It will be updated as expected
+ self.__index = {}
+
+ if self.__userSelectedView is view:
+ self.__userSelectedView = None
+
+ if view is self.__currentView:
+ self.__updateView()
+ else:
+ # TODO It can be skipped if the view is not part of the
+ # available views
+ self.__updateAvailableViews()
+
+ def __updateAvailableViews(self):
+ """
+ Update available views from the current data.
+ """
+ data = self.__data
+ # sort available views according to priority
+ info = DataViews.DataInfo(data)
+ priorities = [v.getDataPriority(data, info) for v in self.__views]
+ views = zip(priorities, self.__views)
+ views = filter(lambda t: t[0] > DataViews.DataView.UNSUPPORTED, views)
+ views = sorted(views, reverse=True)
+
+ # store available views
+ if len(views) == 0:
+ self.__setCurrentAvailableViews([])
+ available = []
+ else:
+ available = [v[1] for v in views]
+ self.__setCurrentAvailableViews(available)
+
+ def __updateView(self):
+ """Display the data using the widget which fit the best"""
+ data = self.__data
+
+ # update available views for this data
+ self.__updateAvailableViews()
+ available = self.__currentAvailableViews
+
+ # display the view with the most priority (the default view)
+ view = self.getDefaultViewFromAvailableViews(data, available)
+ self.__clearCurrentView()
+ try:
+ self._setDisplayedView(view)
+ except Exception as e:
+ # in case there is a problem to read the data, try to use a safe
+ # view
+ view = self.getSafeViewFromAvailableViews(data, available)
+ self._setDisplayedView(view)
+ raise e
+
+ def getSafeViewFromAvailableViews(self, data, available):
+ """Returns a view which is sure to display something without failing
+ on rendering.
+
+ :param object data: data which will be displayed
+ :param list[view] available: List of available views, from highest
+ priority to lowest.
+ :rtype: DataView
+ """
+ hdf5View = self.getViewFromModeId(DataViewer.HDF5_MODE)
+ if hdf5View in available:
+ return hdf5View
+ return self.getViewFromModeId(DataViewer.EMPTY_MODE)
+
+ def getDefaultViewFromAvailableViews(self, data, available):
+ """Returns the default view which will be used according to available
+ views.
+
+ :param object data: data which will be displayed
+ :param list[view] available: List of available views, from highest
+ priority to lowest.
+ :rtype: DataView
+ """
+ if len(available) > 0:
+ # returns the view with the highest priority
+ if self.__userSelectedView in available:
+ return self.__userSelectedView
+ self.__userSelectedView = None
+ view = available[0]
+ else:
+ # else returns the empty view
+ view = self.getViewFromModeId(DataViewer.EMPTY_MODE)
+ return view
+
+ def __setCurrentAvailableViews(self, availableViews):
+ """Set the current available viewa
+
+ :param List[DataView] availableViews: Current available viewa
+ """
+ self.__currentAvailableViews = availableViews
+ self.currentAvailableViewsChanged.emit()
+
+ def currentAvailableViews(self):
+ """Returns the list of available views for the current data
+
+ :rtype: List[DataView]
+ """
+ return self.__currentAvailableViews
+
+ def availableViews(self):
+ """Returns the list of registered views
+
+ :rtype: List[DataView]
+ """
+ return self.__views
+
+ def setData(self, data):
+ """Set the data to view.
+
+ It mostly can be a h5py.Dataset or a numpy.ndarray. Other kind of
+ objects will be displayed as text rendering.
+
+ :param numpy.ndarray data: The data.
+ """
+ self.__data = data
+ self.__displayedData = None
+ self.__updateView()
+ self.__updateNumpySelectionAxis()
+ self.__updateDataInView()
+ self.dataChanged.emit()
+
+ def __numpyAxisChanged(self):
+ """
+ Called when axis selection of the numpy-selector changed
+ """
+ self.__clearCurrentView()
+
+ def __numpySelectionChanged(self):
+ """
+ Called when data selection of the numpy-selector changed
+ """
+ self.__updateDataInView()
+
+ def data(self):
+ """Returns the data"""
+ return self.__data
+
+ def displayMode(self):
+ """Returns the current display mode"""
+ return self.__currentView.modeId()
diff --git a/silx/gui/data/DataViewerFrame.py b/silx/gui/data/DataViewerFrame.py
new file mode 100644
index 0000000..b48fa7b
--- /dev/null
+++ b/silx/gui/data/DataViewerFrame.py
@@ -0,0 +1,186 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module contains a DataViewer with a view selector.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "10/04/2017"
+
+from silx.gui import qt
+from .DataViewer import DataViewer
+from .DataViewerSelector import DataViewerSelector
+
+
+class DataViewerFrame(qt.QWidget):
+ """
+ A :class:`DataViewer` with a view selector.
+
+ .. image:: img/DataViewerFrame.png
+
+ This widget provides the same API as :class:`DataViewer`. Therefore, for more
+ documentation, take a look at the documentation of the class
+ :class:`DataViewer`.
+
+ .. code-block:: python
+
+ import numpy
+ data = numpy.random.rand(500,500)
+ viewer = DataViewerFrame()
+ viewer.setData(data)
+ viewer.setVisible(True)
+
+ """
+
+ displayedViewChanged = qt.Signal(object)
+ """Emitted when the displayed view changes"""
+
+ dataChanged = qt.Signal()
+ """Emitted when the data changes"""
+
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QWidget parent:
+ """
+ super(DataViewerFrame, self).__init__(parent)
+
+ class _DataViewer(DataViewer):
+ """Overwrite methods to avoid to create views while the instance
+ is not created. `initializeViews` have to be called manually."""
+
+ def _initializeViews(self):
+ pass
+
+ def initializeViews(self):
+ """Avoid to create views while the instance is not created."""
+ super(_DataViewer, self)._initializeViews()
+
+ self.__dataViewer = _DataViewer(self)
+ # initialize views when `self.__dataViewer` is set
+ self.__dataViewer.initializeViews()
+ self.__dataViewer.setFrameShape(qt.QFrame.StyledPanel)
+ self.__dataViewer.setFrameShadow(qt.QFrame.Sunken)
+ self.__dataViewerSelector = DataViewerSelector(self, self.__dataViewer)
+ self.__dataViewerSelector.setFlat(True)
+
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+ layout.addWidget(self.__dataViewer, 1)
+ layout.addWidget(self.__dataViewerSelector)
+ self.setLayout(layout)
+
+ self.__dataViewer.dataChanged.connect(self.__dataChanged)
+ self.__dataViewer.displayedViewChanged.connect(self.__displayedViewChanged)
+
+ def __dataChanged(self):
+ """Called when the data is changed"""
+ self.dataChanged.emit()
+
+ def __displayedViewChanged(self, view):
+ """Called when the displayed view changes"""
+ self.displayedViewChanged.emit(view)
+
+ def availableViews(self):
+ """Returns the list of registered views
+
+ :rtype: List[DataView]
+ """
+ return self.__dataViewer.availableViews()
+
+ def currentAvailableViews(self):
+ """Returns the list of available views for the current data
+
+ :rtype: List[DataView]
+ """
+ return self.__dataViewer.currentAvailableViews()
+
+ def createDefaultViews(self, parent=None):
+ """Create and returns available views which can be displayed by default
+ by the data viewer. It is called internally by the widget. It can be
+ overwriten to provide a different set of viewers.
+
+ :param QWidget parent: QWidget parent of the views
+ :rtype: list[silx.gui.data.DataViews.DataView]
+ """
+ return self.__dataViewer.createDefaultViews(parent)
+
+ def addView(self, view):
+ """Allow to add a view to the dataview.
+
+ If the current data support this view, it will be displayed.
+
+ :param DataView view: A dataview
+ """
+ return self.__dataViewer.addView(view)
+
+ def removeView(self, view):
+ """Allow to remove a view which was available from the dataview.
+
+ If the view was displayed, the widget will be updated.
+
+ :param DataView view: A dataview
+ """
+ return self.__dataViewer.removeView(view)
+
+ def setData(self, data):
+ """Set the data to view.
+
+ It mostly can be a h5py.Dataset or a numpy.ndarray. Other kind of
+ objects will be displayed as text rendering.
+
+ :param numpy.ndarray data: The data.
+ """
+ self.__dataViewer.setData(data)
+
+ def data(self):
+ """Returns the data"""
+ return self.__dataViewer.data()
+
+ def setDisplayedView(self, view):
+ self.__dataViewer.setDisplayedView(view)
+
+ def displayedView(self):
+ return self.__dataViewer.displayedView()
+
+ def displayMode(self):
+ return self.__dataViewer.displayMode()
+
+ def setDisplayMode(self, modeId):
+ """Set the displayed view using display mode.
+
+ Change the displayed view according to the requested mode.
+
+ :param int modeId: Display mode, one of
+
+ - `EMPTY_MODE`: display nothing
+ - `PLOT1D_MODE`: display the data as a curve
+ - `PLOT2D_MODE`: display the data as an image
+ - `TEXT_MODE`: display the data as a text
+ - `ARRAY_MODE`: display the data as a table
+ """
+ return self.__dataViewer.setDisplayMode(modeId)
diff --git a/silx/gui/data/DataViewerSelector.py b/silx/gui/data/DataViewerSelector.py
new file mode 100644
index 0000000..32cc636
--- /dev/null
+++ b/silx/gui/data/DataViewerSelector.py
@@ -0,0 +1,153 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines a widget to be able to select the available view
+of the DataViewer.
+"""
+from __future__ import division
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "26/01/2017"
+
+import weakref
+import functools
+from silx.gui import qt
+from silx.gui.data.DataViewer import DataViewer
+import silx.utils.weakref
+
+
+class DataViewerSelector(qt.QWidget):
+ """Widget to be able to select a custom view from the DataViewer"""
+
+ def __init__(self, parent=None, dataViewer=None):
+ """Constructor
+
+ :param QWidget parent: The parent of the widget
+ :param DataViewer dataViewer: The connected `DataViewer`
+ """
+ super(DataViewerSelector, self).__init__(parent)
+
+ self.__group = None
+ self.__buttons = {}
+ self.__buttonDummy = None
+ self.__dataViewer = None
+
+ if dataViewer is not None:
+ self.setDataViewer(dataViewer)
+
+ def __updateButtons(self):
+ if self.__group is not None:
+ self.__group.deleteLater()
+ self.__buttons = {}
+ self.__buttonDummy = None
+
+ self.__group = qt.QButtonGroup(self)
+ self.setLayout(qt.QHBoxLayout())
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ if self.__dataViewer is None:
+ return
+
+ iconSize = qt.QSize(16, 16)
+
+ for view in self.__dataViewer.availableViews():
+ label = view.label()
+ icon = view.icon()
+ button = qt.QPushButton(label)
+ button.setIcon(icon)
+ button.setIconSize(iconSize)
+ button.setCheckable(True)
+ # the weak objects are needed to be able to destroy the widget safely
+ weakView = weakref.ref(view)
+ weakMethod = silx.utils.weakref.WeakMethodProxy(self.__setDisplayedView)
+ callback = functools.partial(weakMethod, weakView)
+ button.clicked.connect(callback)
+ self.layout().addWidget(button)
+ self.__group.addButton(button)
+ self.__buttons[view] = button
+
+ button = qt.QPushButton("Dummy")
+ button.setCheckable(True)
+ button.setVisible(False)
+ self.layout().addWidget(button)
+ self.__group.addButton(button)
+ self.__buttonDummy = button
+
+ self.layout().addStretch(1)
+
+ self.__updateButtonsVisibility()
+ self.__displayedViewChanged(self.__dataViewer.displayedView())
+
+ def setDataViewer(self, dataViewer):
+ """Define the dataviewer connected to this status bar
+
+ :param DataViewer dataViewer: The connected `DataViewer`
+ """
+ if self.__dataViewer is dataViewer:
+ return
+ if self.__dataViewer is not None:
+ self.__dataViewer.dataChanged.disconnect(self.__updateButtonsVisibility)
+ self.__dataViewer.displayedViewChanged.disconnect(self.__displayedViewChanged)
+ self.__dataViewer = dataViewer
+ if self.__dataViewer is not None:
+ self.__dataViewer.dataChanged.connect(self.__updateButtonsVisibility)
+ self.__dataViewer.displayedViewChanged.connect(self.__displayedViewChanged)
+ self.__updateButtons()
+
+ def setFlat(self, isFlat):
+ """Set the flat state of all the buttons.
+
+ :param bool isFlat: True to display the buttons flatten.
+ """
+ for b in self.__buttons.values():
+ b.setFlat(isFlat)
+ self.__buttonDummy.setFlat(isFlat)
+
+ def __displayedViewChanged(self, view):
+ """Called on displayed view changeS"""
+ selectedButton = self.__buttons.get(view, self.__buttonDummy)
+ selectedButton.setChecked(True)
+
+ def __setDisplayedView(self, refView, clickEvent=None):
+ """Display a data using the requested view
+
+ :param DataView view: Requested view
+ :param clickEvent: Event sent by the clicked event
+ """
+ if self.__dataViewer is None:
+ return
+ view = refView()
+ if view is None:
+ return
+ self.__dataViewer.setDisplayedView(view)
+
+ def __updateButtonsVisibility(self):
+ """Called on data changed"""
+ if self.__dataViewer is None:
+ for b in self.__buttons.values():
+ b.setVisible(False)
+ else:
+ availableViews = set(self.__dataViewer.currentAvailableViews())
+ for view, button in self.__buttons.items():
+ button.setVisible(view in availableViews)
diff --git a/silx/gui/data/DataViews.py b/silx/gui/data/DataViews.py
new file mode 100644
index 0000000..d8d605a
--- /dev/null
+++ b/silx/gui/data/DataViews.py
@@ -0,0 +1,988 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines a views used by :class:`silx.gui.data.DataViewer`.
+"""
+
+import logging
+import numbers
+import numpy
+
+import silx.io
+from silx.gui import qt, icons
+from silx.gui.data.TextFormatter import TextFormatter
+from silx.io import nxdata
+from silx.gui.hdf5 import H5Node
+from silx.io.nxdata import NXdata
+
+__authors__ = ["V. Valls", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "07/04/2017"
+
+_logger = logging.getLogger(__name__)
+
+
+# DataViewer modes
+EMPTY_MODE = 0
+PLOT1D_MODE = 10
+PLOT2D_MODE = 20
+PLOT3D_MODE = 30
+RAW_MODE = 40
+RAW_ARRAY_MODE = 41
+RAW_RECORD_MODE = 42
+RAW_SCALAR_MODE = 43
+STACK_MODE = 50
+HDF5_MODE = 60
+
+
+def _normalizeData(data):
+ """Returns a normalized data.
+
+ If the data embed a numpy data or a dataset it is returned.
+ Else returns the input data."""
+ if isinstance(data, H5Node):
+ return data.h5py_object
+ return data
+
+
+def _normalizeComplex(data):
+ """Returns a normalized complex data.
+
+ If the data is a numpy data with complex, returns the
+ absolute value.
+ Else returns the input data."""
+ if hasattr(data, "dtype"):
+ isComplex = numpy.issubdtype(data.dtype, numpy.complex)
+ else:
+ isComplex = isinstance(data, numbers.Complex)
+ if isComplex:
+ data = numpy.absolute(data)
+ return data
+
+
+class DataInfo(object):
+ """Store extracted information from a data"""
+
+ def __init__(self, data):
+ data = self.normalizeData(data)
+ self.isArray = False
+ self.interpretation = None
+ self.isNumeric = False
+ self.isComplex = False
+ self.isRecord = False
+ self.isNXdata = False
+ self.shape = tuple()
+ self.dim = 0
+
+ if data is None:
+ return
+
+ if silx.io.is_group(data) and nxdata.is_valid_nxdata(data):
+ self.isNXdata = True
+ nxd = nxdata.NXdata(data)
+
+ if isinstance(data, numpy.ndarray):
+ self.isArray = True
+ elif silx.io.is_dataset(data) and data.shape != tuple():
+ self.isArray = True
+ else:
+ self.isArray = False
+
+ if silx.io.is_dataset(data):
+ self.interpretation = data.attrs.get("interpretation", None)
+ elif self.isNXdata:
+ self.interpretation = nxd.interpretation
+ else:
+ self.interpretation = None
+
+ if hasattr(data, "dtype"):
+ self.isNumeric = numpy.issubdtype(data.dtype, numpy.number)
+ self.isRecord = data.dtype.fields is not None
+ self.isComplex = numpy.issubdtype(data.dtype, numpy.complex)
+ elif self.isNXdata:
+ self.isNumeric = numpy.issubdtype(nxd.signal.dtype,
+ numpy.number)
+ self.isComplex = numpy.issubdtype(nxd.signal.dtype, numpy.complex)
+ else:
+ self.isNumeric = isinstance(data, numbers.Number)
+ self.isComplex = isinstance(data, numbers.Complex)
+ self.isRecord = False
+
+ if hasattr(data, "shape"):
+ self.shape = data.shape
+ elif self.isNXdata:
+ self.shape = nxd.signal.shape
+ else:
+ self.shape = tuple()
+ self.dim = len(self.shape)
+
+ def normalizeData(self, data):
+ """Returns a normalized data if the embed a numpy or a dataset.
+ Else returns the data."""
+ return _normalizeData(data)
+
+
+class DataView(object):
+ """Holder for the data view."""
+
+ UNSUPPORTED = -1
+ """Priority returned when the requested data can't be displayed by the
+ view."""
+
+ def __init__(self, parent, modeId=None, icon=None, label=None):
+ """Constructor
+
+ :param qt.QWidget parent: Parent of the hold widget
+ """
+ self.__parent = parent
+ self.__widget = None
+ self.__modeId = modeId
+ if label is None:
+ label = self.__class__.__name__
+ self.__label = label
+ if icon is None:
+ icon = qt.QIcon()
+ self.__icon = icon
+
+ def icon(self):
+ """Returns the default icon"""
+ return self.__icon
+
+ def label(self):
+ """Returns the default label"""
+ return self.__label
+
+ def modeId(self):
+ """Returns the mode id"""
+ return self.__modeId
+
+ def normalizeData(self, data):
+ """Returns a normalized data if the embed a numpy or a dataset.
+ Else returns the data."""
+ return _normalizeData(data)
+
+ def customAxisNames(self):
+ """Returns names of axes which can be custom by the user and provided
+ to the view."""
+ return []
+
+ def setCustomAxisValue(self, name, value):
+ """
+ Set the value of a custom axis
+
+ :param str name: Name of the custom axis
+ :param int value: Value of the custom axis
+ """
+ pass
+
+ def isWidgetInitialized(self):
+ """Returns true if the widget is already initialized.
+ """
+ return self.__widget is not None
+
+ def select(self):
+ """Called when the view is selected to display the data.
+ """
+ return
+
+ def getWidget(self):
+ """Returns the widget hold in the view and displaying the data.
+
+ :returns: qt.QWidget
+ """
+ if self.__widget is None:
+ self.__widget = self.createWidget(self.__parent)
+ return self.__widget
+
+ def createWidget(self, parent):
+ """Create the the widget displaying the data
+
+ :param qt.QWidget parent: Parent of the widget
+ :returns: qt.QWidget
+ """
+ raise NotImplementedError()
+
+ def clear(self):
+ """Clear the data from the view"""
+ return None
+
+ def setData(self, data):
+ """Set the data displayed by the view
+
+ :param data: Data to display
+ :type data: numpy.ndarray or h5py.Dataset
+ """
+ return None
+
+ def axesNames(self, data, info):
+ """Returns names of the expected axes of the view, according to the
+ input data.
+
+ :param data: Data to display
+ :type data: numpy.ndarray or h5py.Dataset
+ :param DataInfo info: Pre-computed information on the data
+ :rtype: list[str]
+ """
+ return []
+
+ def getDataPriority(self, data, info):
+ """
+ Returns the priority of using this view according to a data.
+
+ - `UNSUPPORTED` means this view can't display this data
+ - `1` means this view can display the data
+ - `100` means this view should be used for this data
+ - `1000` max value used by the views provided by silx
+ - ...
+
+ :param object data: The data to check
+ :param DataInfo info: Pre-computed information on the data
+ :rtype: int
+ """
+ return DataView.UNSUPPORTED
+
+ def __lt__(self, other):
+ return str(self) < str(other)
+
+
+class CompositeDataView(DataView):
+ """Data view which can display a data using different view according to
+ the kind of the data."""
+
+ def __init__(self, parent, modeId=None, icon=None, label=None):
+ """Constructor
+
+ :param qt.QWidget parent: Parent of the hold widget
+ """
+ super(CompositeDataView, self).__init__(parent, modeId, icon, label)
+ self.__views = {}
+ self.__currentView = None
+
+ def addView(self, dataView):
+ """Add a new dataview to the available list."""
+ self.__views[dataView] = None
+
+ def getBestView(self, data, info):
+ """Returns the best view according to priorities."""
+ info = DataInfo(data)
+ views = [(v.getDataPriority(data, info), v) for v in self.__views.keys()]
+ views = filter(lambda t: t[0] > DataView.UNSUPPORTED, views)
+ views = sorted(views, reverse=True)
+
+ if len(views) == 0:
+ return None
+ elif views[0][0] == DataView.UNSUPPORTED:
+ return None
+ else:
+ return views[0][1]
+
+ def customAxisNames(self):
+ if self.__currentView is None:
+ return
+ return self.__currentView.customAxisNames()
+
+ def setCustomAxisValue(self, name, value):
+ if self.__currentView is None:
+ return
+ self.__currentView.setCustomAxisValue(name, value)
+
+ def __updateDisplayedView(self):
+ widget = self.getWidget()
+ if self.__currentView is None:
+ return
+
+ # load the widget if it is not yet done
+ index = self.__views[self.__currentView]
+ if index is None:
+ w = self.__currentView.getWidget()
+ index = widget.addWidget(w)
+ self.__views[self.__currentView] = index
+ if widget.currentIndex() != index:
+ widget.setCurrentIndex(index)
+ self.__currentView.select()
+
+ def select(self):
+ self.__updateDisplayedView()
+ if self.__currentView is not None:
+ self.__currentView.select()
+
+ def createWidget(self, parent):
+ return qt.QStackedWidget()
+
+ def clear(self):
+ for v in self.__views.keys():
+ v.clear()
+
+ def setData(self, data):
+ if self.__currentView is None:
+ return
+ self.__updateDisplayedView()
+ self.__currentView.setData(data)
+
+ def axesNames(self, data, info):
+ view = self.getBestView(data, info)
+ self.__currentView = view
+ return view.axesNames(data, info)
+
+ def getDataPriority(self, data, info):
+ view = self.getBestView(data, info)
+ self.__currentView = view
+ if view is None:
+ return DataView.UNSUPPORTED
+ else:
+ return view.getDataPriority(data, info)
+
+
+class _EmptyView(DataView):
+ """Dummy view to display nothing"""
+
+ def __init__(self, parent):
+ DataView.__init__(self, parent, modeId=EMPTY_MODE)
+
+ def axesNames(self, data, info):
+ return []
+
+ def createWidget(self, parent):
+ return qt.QLabel(parent)
+
+ def getDataPriority(self, data, info):
+ return DataView.UNSUPPORTED
+
+
+class _Plot1dView(DataView):
+ """View displaying data using a 1d plot"""
+
+ def __init__(self, parent):
+ super(_Plot1dView, self).__init__(
+ parent=parent,
+ modeId=PLOT1D_MODE,
+ label="Curve",
+ icon=icons.getQIcon("view-1d"))
+ self.__resetZoomNextTime = True
+
+ def createWidget(self, parent):
+ from silx.gui import plot
+ return plot.Plot1D(parent=parent)
+
+ def clear(self):
+ self.getWidget().clear()
+ self.__resetZoomNextTime = True
+
+ def normalizeData(self, data):
+ data = DataView.normalizeData(self, data)
+ data = _normalizeComplex(data)
+ return data
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ self.getWidget().addCurve(legend="data",
+ x=range(len(data)),
+ y=data,
+ resetzoom=self.__resetZoomNextTime)
+ self.__resetZoomNextTime = True
+
+ def axesNames(self, data, info):
+ return ["y"]
+
+ def getDataPriority(self, data, info):
+ if data is None or not info.isArray or not info.isNumeric:
+ return DataView.UNSUPPORTED
+ if info.dim < 1:
+ return DataView.UNSUPPORTED
+ if info.interpretation == "spectrum":
+ return 1000
+ if info.dim == 2 and info.shape[0] == 1:
+ return 210
+ if info.dim == 1:
+ return 100
+ else:
+ return 10
+
+
+class _Plot2dView(DataView):
+ """View displaying data using a 2d plot"""
+
+ def __init__(self, parent):
+ super(_Plot2dView, self).__init__(
+ parent=parent,
+ modeId=PLOT2D_MODE,
+ label="Image",
+ icon=icons.getQIcon("view-2d"))
+ self.__resetZoomNextTime = True
+
+ def createWidget(self, parent):
+ from silx.gui import plot
+ widget = plot.Plot2D(parent=parent)
+ widget.setKeepDataAspectRatio(True)
+ widget.setGraphXLabel('X')
+ widget.setGraphYLabel('Y')
+ return widget
+
+ def clear(self):
+ self.getWidget().clear()
+ self.__resetZoomNextTime = True
+
+ def normalizeData(self, data):
+ data = DataView.normalizeData(self, data)
+ data = _normalizeComplex(data)
+ return data
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ self.getWidget().addImage(legend="data",
+ data=data,
+ resetzoom=self.__resetZoomNextTime)
+ self.__resetZoomNextTime = False
+
+ def axesNames(self, data, info):
+ return ["y", "x"]
+
+ def getDataPriority(self, data, info):
+ if data is None or not info.isArray or not info.isNumeric:
+ return DataView.UNSUPPORTED
+ if info.dim < 2:
+ return DataView.UNSUPPORTED
+ if info.interpretation == "image":
+ return 1000
+ if info.dim == 2:
+ return 200
+ else:
+ return 190
+
+
+class _Plot3dView(DataView):
+ """View displaying data using a 3d plot"""
+
+ def __init__(self, parent):
+ super(_Plot3dView, self).__init__(
+ parent=parent,
+ modeId=PLOT3D_MODE,
+ label="Cube",
+ icon=icons.getQIcon("view-3d"))
+ try:
+ import silx.gui.plot3d #noqa
+ except ImportError:
+ _logger.warning("Plot3dView is not available")
+ _logger.debug("Backtrace", exc_info=True)
+ raise
+ self.__resetZoomNextTime = True
+
+ def createWidget(self, parent):
+ from silx.gui.plot3d import ScalarFieldView
+ from silx.gui.plot3d import SFViewParamTree
+
+ plot = ScalarFieldView.ScalarFieldView(parent)
+ plot.setAxesLabels(*reversed(self.axesNames(None, None)))
+ plot.addIsosurface(
+ lambda data: numpy.mean(data) + numpy.std(data), '#FF0000FF')
+
+ # Create a parameter tree for the scalar field view
+ options = SFViewParamTree.TreeView(plot)
+ options.setSfView(plot)
+
+ # Add the parameter tree to the main window in a dock widget
+ dock = qt.QDockWidget()
+ dock.setWidget(options)
+ plot.addDockWidget(qt.Qt.RightDockWidgetArea, dock)
+
+ return plot
+
+ def clear(self):
+ self.getWidget().setData(None)
+ self.__resetZoomNextTime = True
+
+ def normalizeData(self, data):
+ data = DataView.normalizeData(self, data)
+ data = _normalizeComplex(data)
+ return data
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ plot = self.getWidget()
+ plot.setData(data)
+ self.__resetZoomNextTime = False
+
+ def axesNames(self, data, info):
+ return ["z", "y", "x"]
+
+ def getDataPriority(self, data, info):
+ if data is None or not info.isArray or not info.isNumeric:
+ return DataView.UNSUPPORTED
+ if info.dim < 3:
+ return DataView.UNSUPPORTED
+ if min(data.shape) < 2:
+ return DataView.UNSUPPORTED
+ if info.dim == 3:
+ return 100
+ else:
+ return 10
+
+
+class _ArrayView(DataView):
+ """View displaying data using a 2d table"""
+
+ def __init__(self, parent):
+ DataView.__init__(self, parent, modeId=RAW_ARRAY_MODE)
+
+ def createWidget(self, parent):
+ from silx.gui.data.ArrayTableWidget import ArrayTableWidget
+ widget = ArrayTableWidget(parent)
+ widget.displayAxesSelector(False)
+ return widget
+
+ def clear(self):
+ self.getWidget().setArrayData(numpy.array([[]]))
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ self.getWidget().setArrayData(data)
+
+ def axesNames(self, data, info):
+ return ["col", "row"]
+
+ def getDataPriority(self, data, info):
+ if data is None or not info.isArray or info.isRecord:
+ return DataView.UNSUPPORTED
+ if info.dim < 2:
+ return DataView.UNSUPPORTED
+ if info.interpretation in ["scalar", "scaler"]:
+ return 1000
+ return 500
+
+
+class _StackView(DataView):
+ """View displaying data using a stack of images"""
+
+ def __init__(self, parent):
+ super(_StackView, self).__init__(
+ parent=parent,
+ modeId=STACK_MODE,
+ label="Image stack",
+ icon=icons.getQIcon("view-2d-stack"))
+ self.__resetZoomNextTime = True
+
+ def customAxisNames(self):
+ return ["depth"]
+
+ def setCustomAxisValue(self, name, value):
+ if name == "depth":
+ self.getWidget().setFrameNumber(value)
+ else:
+ raise Exception("Unsupported axis")
+
+ def createWidget(self, parent):
+ from silx.gui import plot
+ widget = plot.StackView(parent=parent)
+ widget.setKeepDataAspectRatio(True)
+ widget.setLabels(self.axesNames(None, None))
+ # hide default option panel
+ widget.setOptionVisible(False)
+ return widget
+
+ def clear(self):
+ self.getWidget().clear()
+ self.__resetZoomNextTime = True
+
+ def normalizeData(self, data):
+ data = DataView.normalizeData(self, data)
+ data = _normalizeComplex(data)
+ return data
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ self.getWidget().setStack(stack=data, reset=self.__resetZoomNextTime)
+ self.__resetZoomNextTime = False
+
+ def axesNames(self, data, info):
+ return ["depth", "y", "x"]
+
+ def getDataPriority(self, data, info):
+ if data is None or not info.isArray or not info.isNumeric:
+ return DataView.UNSUPPORTED
+ if info.dim < 3:
+ return DataView.UNSUPPORTED
+ if info.interpretation == "image":
+ return 500
+ return 90
+
+
+class _ScalarView(DataView):
+ """View displaying data using text"""
+
+ def __init__(self, parent):
+ DataView.__init__(self, parent, modeId=RAW_SCALAR_MODE)
+
+ def createWidget(self, parent):
+ widget = qt.QTextEdit(parent)
+ widget.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
+ widget.setAlignment(qt.Qt.AlignLeft | qt.Qt.AlignTop)
+ self.__formatter = TextFormatter(parent)
+ return widget
+
+ def clear(self):
+ self.getWidget().setText("")
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ if silx.io.is_dataset(data):
+ data = data[()]
+ text = self.__formatter.toString(data)
+ self.getWidget().setText(text)
+
+ def axesNames(self, data, info):
+ return []
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if data is None:
+ return DataView.UNSUPPORTED
+ if silx.io.is_group(data):
+ return DataView.UNSUPPORTED
+ return 2
+
+
+class _RecordView(DataView):
+ """View displaying data using text"""
+
+ def __init__(self, parent):
+ DataView.__init__(self, parent, modeId=RAW_RECORD_MODE)
+
+ def createWidget(self, parent):
+ from .RecordTableView import RecordTableView
+ widget = RecordTableView(parent)
+ widget.setWordWrap(False)
+ return widget
+
+ def clear(self):
+ self.getWidget().setArrayData(None)
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ widget = self.getWidget()
+ widget.setArrayData(data)
+ widget.resizeRowsToContents()
+ widget.resizeColumnsToContents()
+
+ def axesNames(self, data, info):
+ return ["data"]
+
+ def getDataPriority(self, data, info):
+ if info.isRecord:
+ return 40
+ if data is None or not info.isArray:
+ return DataView.UNSUPPORTED
+ if info.dim == 1:
+ if info.interpretation in ["scalar", "scaler"]:
+ return 1000
+ if info.shape[0] == 1:
+ return 510
+ return 500
+ elif info.isRecord:
+ return 40
+ return DataView.UNSUPPORTED
+
+
+class _Hdf5View(DataView):
+ """View displaying data using text"""
+
+ def __init__(self, parent):
+ super(_Hdf5View, self).__init__(
+ parent=parent,
+ modeId=HDF5_MODE,
+ label="HDF5",
+ icon=icons.getQIcon("view-hdf5"))
+
+ def createWidget(self, parent):
+ from .Hdf5TableView import Hdf5TableView
+ widget = Hdf5TableView(parent)
+ return widget
+
+ def clear(self):
+ widget = self.getWidget()
+ widget.setData(None)
+
+ def setData(self, data):
+ widget = self.getWidget()
+ widget.setData(data)
+
+ def axesNames(self, data, info):
+ return []
+
+ def getDataPriority(self, data, info):
+ widget = self.getWidget()
+ if widget.isSupportedData(data):
+ return 1
+ else:
+ return DataView.UNSUPPORTED
+
+
+class _RawView(CompositeDataView):
+ """View displaying data as raw data.
+
+ This implementation use a 2d-array view, or a record array view, or a
+ raw text output.
+ """
+
+ def __init__(self, parent):
+ super(_RawView, self).__init__(
+ parent=parent,
+ modeId=RAW_MODE,
+ label="Raw",
+ icon=icons.getQIcon("view-raw"))
+ self.addView(_ScalarView(parent))
+ self.addView(_ArrayView(parent))
+ self.addView(_RecordView(parent))
+
+
+class _NXdataScalarView(DataView):
+ """DataView using a table view for displaying NXdata scalars:
+ 0-D signal or n-D signal with *@interpretation=scalar*"""
+ def __init__(self, parent):
+ DataView.__init__(self, parent)
+
+ def createWidget(self, parent):
+ from silx.gui.data.ArrayTableWidget import ArrayTableWidget
+ widget = ArrayTableWidget(parent)
+ # widget.displayAxesSelector(False)
+ return widget
+
+ def axesNames(self, data, info):
+ return ["col", "row"]
+
+ def clear(self):
+ self.getWidget().setArrayData(numpy.array([[]]),
+ labels=True)
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ signal = NXdata(data).signal
+ self.getWidget().setArrayData(signal,
+ labels=True)
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if info.isNXdata:
+ nxd = NXdata(data)
+ if nxd.signal_is_0d or nxd.interpretation in ["scalar", "scaler"]:
+ return 100
+ return DataView.UNSUPPORTED
+
+
+class _NXdataCurveView(DataView):
+ """DataView using a Plot1D for displaying NXdata curves:
+ 1-D signal or n-D signal with *@interpretation=spectrum*.
+
+ It also handles basic scatter plots:
+ a 1-D signal with one axis whose values are not monotonically increasing.
+ """
+ def __init__(self, parent):
+ DataView.__init__(self, parent)
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayCurvePlot
+ widget = ArrayCurvePlot(parent)
+ return widget
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return []
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = NXdata(data)
+ signal_name = data.attrs["signal"]
+ group_name = data.name
+ if nxd.axes_names[-1] is not None:
+ x_errors = nxd.get_axis_errors(nxd.axes_names[-1])
+ else:
+ x_errors = None
+
+ self.getWidget().setCurveData(nxd.signal, nxd.axes[-1],
+ yerror=nxd.errors, xerror=x_errors,
+ ylabel=signal_name, xlabel=nxd.axes_names[-1],
+ title="NXdata group " + group_name)
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if info.isNXdata:
+ nxd = NXdata(data)
+ if nxd.is_x_y_value_scatter or nxd.is_unsupported_scatter:
+ return DataView.UNSUPPORTED
+ if nxd.signal_is_1d and \
+ not nxd.interpretation in ["scalar", "scaler"]:
+ return 100
+ if nxd.interpretation == "spectrum":
+ return 100
+ return DataView.UNSUPPORTED
+
+
+class _NXdataXYVScatterView(DataView):
+ """DataView using a Plot1D for displaying NXdata 3D scatters as
+ a scatter of coloured points (1-D signal with 2 axes)"""
+ def __init__(self, parent):
+ DataView.__init__(self, parent)
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayCurvePlot
+ widget = ArrayCurvePlot(parent)
+ return widget
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return []
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = NXdata(data)
+ signal_name = data.attrs["signal"]
+ # signal_errors = nx.errors # not supported
+ group_name = data.name
+ x_axis, y_axis = nxd.axes[-2:]
+
+ x_label, y_label = nxd.axes_names[-2:]
+ if x_label is not None:
+ x_errors = nxd.get_axis_errors(x_label)
+ else:
+ x_errors = None
+
+ if y_label is not None:
+ y_errors = nxd.get_axis_errors(y_label)
+ else:
+ y_errors = None
+
+ self.getWidget().setCurveData(y_axis, x_axis, values=nxd.signal,
+ yerror=y_errors, xerror=x_errors,
+ ylabel=signal_name, xlabel=x_label,
+ title="NXdata group " + group_name)
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if info.isNXdata:
+ if NXdata(data).is_x_y_value_scatter:
+ return 100
+ return DataView.UNSUPPORTED
+
+
+class _NXdataImageView(DataView):
+ """DataView using a Plot2D for displaying NXdata images:
+ 2-D signal or n-D signals with *@interpretation=spectrum*."""
+ def __init__(self, parent):
+ DataView.__init__(self, parent)
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayImagePlot
+ widget = ArrayImagePlot(parent)
+ return widget
+
+ def axesNames(self, data, info):
+ return []
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = NXdata(data)
+ signal_name = data.attrs["signal"]
+ group_name = data.name
+ y_axis, x_axis = nxd.axes[-2:]
+ y_label, x_label = nxd.axes_names[-2:]
+
+ self.getWidget().setImageData(
+ nxd.signal, x_axis=x_axis, y_axis=y_axis,
+ signal_name=signal_name, xlabel=x_label, ylabel=y_label,
+ title="NXdata group %s: %s" % (group_name, signal_name))
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if info.isNXdata:
+ nxd = NXdata(data)
+ if nxd.signal_is_2d:
+ if nxd.interpretation not in ["scalar", "spectrum", "scaler"]:
+ return 100
+ if nxd.interpretation == "image":
+ return 100
+ return DataView.UNSUPPORTED
+
+
+class _NXdataStackView(DataView):
+ def __init__(self, parent):
+ DataView.__init__(self, parent)
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayStackPlot
+ widget = ArrayStackPlot(parent)
+ return widget
+
+ def axesNames(self, data, info):
+ return []
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = NXdata(data)
+ signal_name = data.attrs["signal"]
+ group_name = data.name
+ z_axis, y_axis, x_axis = nxd.axes[-3:]
+ z_label, y_label, x_label = nxd.axes_names[-3:]
+
+ self.getWidget().setStackData(
+ nxd.signal, x_axis=x_axis, y_axis=y_axis, z_axis=z_axis,
+ signal_name=signal_name,
+ xlabel=x_label, ylabel=y_label, zlabel=z_label,
+ title="NXdata group %s: %s" % (group_name, signal_name))
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if info.isNXdata:
+ nxd = NXdata(data)
+ if nxd.signal_ndim >= 3:
+ if nxd.interpretation not in ["scalar", "scaler",
+ "spectrum", "image"]:
+ return 100
+ return DataView.UNSUPPORTED
+
+
+class _NXdataView(CompositeDataView):
+ """Composite view displaying NXdata groups using the most adequate
+ widget depending on the dimensionality."""
+ def __init__(self, parent):
+ super(_NXdataView, self).__init__(
+ parent=parent,
+ label="NXdata",
+ icon=icons.getQIcon("view-nexus"))
+
+ self.addView(_NXdataScalarView(parent))
+ self.addView(_NXdataCurveView(parent))
+ self.addView(_NXdataXYVScatterView(parent))
+ self.addView(_NXdataImageView(parent))
+ self.addView(_NXdataStackView(parent))
diff --git a/silx/gui/data/Hdf5TableView.py b/silx/gui/data/Hdf5TableView.py
new file mode 100644
index 0000000..5d79907
--- /dev/null
+++ b/silx/gui/data/Hdf5TableView.py
@@ -0,0 +1,414 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module define model and widget to display 1D slices from numpy
+array using compound data types or hdf5 databases.
+"""
+from __future__ import division
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "07/04/2017"
+
+import functools
+import os.path
+import logging
+from silx.gui import qt
+import silx.io
+from .TextFormatter import TextFormatter
+import silx.gui.hdf5
+from silx.gui.widgets import HierarchicalTableView
+
+_logger = logging.getLogger(__name__)
+
+
+class _CellData(object):
+ """Store a table item
+ """
+ def __init__(self, value=None, isHeader=False, span=None):
+ """
+ Constructor
+
+ :param str value: Label of this property
+ :param bool isHeader: True if the cell is an header
+ :param tuple span: Tuple of row, column span
+ """
+ self.__value = value
+ self.__isHeader = isHeader
+ self.__span = span
+
+ def isHeader(self):
+ """Returns true if the property is a sub-header title.
+
+ :rtype: bool
+ """
+ return self.__isHeader
+
+ def value(self):
+ """Returns the value of the item.
+ """
+ return self.__value
+
+ def span(self):
+ """Returns the span size of the cell.
+
+ :rtype: tuple
+ """
+ return self.__span
+
+
+class _TableData(object):
+ """Modelize a table with header, row and column span.
+
+ It is mostly defined as a row based table.
+ """
+
+ def __init__(self, columnCount):
+ """Constructor.
+
+ :param int columnCount: Define the number of column of the table
+ """
+ self.__colCount = columnCount
+ self.__data = []
+
+ def rowCount(self):
+ """Returns the number of rows.
+
+ :rtype: int
+ """
+ return len(self.__data)
+
+ def columnCount(self):
+ """Returns the number of columns.
+
+ :rtype: int
+ """
+ return self.__colCount
+
+ def clear(self):
+ """Remove all the cells of the table"""
+ self.__data = []
+
+ def cellAt(self, row, column):
+ """Returns the cell at the row column location. Else None if there is
+ nothing.
+
+ :rtype: _CellData
+ """
+ if row < 0:
+ return None
+ if column < 0:
+ return None
+ if row >= len(self.__data):
+ return None
+ cells = self.__data[row]
+ if column >= len(cells):
+ return None
+ return cells[column]
+
+ def addHeaderRow(self, headerLabel):
+ """Append the table with header on the full row.
+
+ :param str headerLabel: label of the header.
+ """
+ item = _CellData(value=headerLabel, isHeader=True, span=(1, self.__colCount))
+ self.__data.append([item])
+
+ def addHeaderValueRow(self, headerLabel, value):
+ """Append the table with a row using the first column as an header and
+ other cells as a single cell for the value.
+
+ :param str headerLabel: label of the header.
+ :param object value: value to store.
+ """
+ header = _CellData(value=headerLabel, isHeader=True)
+ value = _CellData(value=value, span=(1, self.__colCount))
+ self.__data.append([header, value])
+
+ def addRow(self, *args):
+ """Append the table with a row using arguments for each cells
+
+ :param list[object] args: List of cell values for the row
+ """
+ row = []
+ for value in args:
+ if not isinstance(value, _CellData):
+ value = _CellData(value=value)
+ row.append(value)
+ self.__data.append(row)
+
+
+class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
+ """This data model provides access to HDF5 node content (File, Group,
+ Dataset). Main info, like name, file, attributes... are displayed
+ """
+
+ def __init__(self, parent=None, data=None):
+ """
+ Constructor
+
+ :param qt.QObject parent: Parent object
+ :param object data: An h5py-like object (file, group or dataset)
+ """
+ super(Hdf5TableModel, self).__init__(parent)
+
+ self.__obj = None
+ self.__data = _TableData(columnCount=4)
+ self.__formatter = None
+ formatter = TextFormatter(self)
+ self.setFormatter(formatter)
+ self.setObject(data)
+
+ def rowCount(self, parent_idx=None):
+ """Returns number of rows to be displayed in table"""
+ return self.__data.rowCount()
+
+ def columnCount(self, parent_idx=None):
+ """Returns number of columns to be displayed in table"""
+ return self.__data.columnCount()
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method to access data values
+ in the format ready to be displayed"""
+ if not index.isValid():
+ return None
+
+ cell = self.__data.cellAt(index.row(), index.column())
+ if cell is None:
+ return None
+
+ if role == self.SpanRole:
+ return cell.span()
+ elif role == self.IsHeaderRole:
+ return cell.isHeader()
+ elif role == qt.Qt.DisplayRole:
+ value = cell.value()
+ if callable(value):
+ value = value(self.__obj)
+ return str(value)
+ return None
+
+ def flags(self, index):
+ """QAbstractTableModel method to inform the view whether data
+ is editable or not.
+ """
+ return qt.QAbstractTableModel.flags(self, index)
+
+ def isSupportedObject(self, h5pyObject):
+ """
+ Returns true if the provided object can be modelized using this model.
+ """
+ isSupported = False
+ isSupported = isSupported or silx.io.is_group(h5pyObject)
+ isSupported = isSupported or silx.io.is_dataset(h5pyObject)
+ isSupported = isSupported or isinstance(h5pyObject, silx.gui.hdf5.H5Node)
+ return isSupported
+
+ def setObject(self, h5pyObject):
+ """Set the h5py-like object exposed by the model
+
+ :param h5pyObject: A h5py-like object. It can be a `h5py.Dataset`,
+ a `h5py.File`, a `h5py.Group`. It also can be a,
+ `silx.gui.hdf5.H5Node` which is needed to display some local path
+ information.
+ """
+ if qt.qVersion() > "4.6":
+ self.beginResetModel()
+
+ if h5pyObject is None or self.isSupportedObject(h5pyObject):
+ self.__obj = h5pyObject
+ else:
+ _logger.warning("Object class %s unsupported. Object ignored.", type(h5pyObject))
+ self.__initProperties()
+
+ if qt.qVersion() > "4.6":
+ self.endResetModel()
+ else:
+ self.reset()
+
+ def __initProperties(self):
+ """Initialize the list of available properties according to the defined
+ h5py-like object."""
+ self.__data.clear()
+ if self.__obj is None:
+ return
+
+ obj = self.__obj
+
+ hdf5obj = obj
+ if isinstance(obj, silx.gui.hdf5.H5Node):
+ hdf5obj = obj.h5py_object
+
+ if silx.io.is_file(hdf5obj):
+ objectType = "File"
+ elif silx.io.is_group(hdf5obj):
+ objectType = "Group"
+ elif silx.io.is_dataset(hdf5obj):
+ objectType = "Dataset"
+ else:
+ objectType = obj.__class__.__name__
+ self.__data.addHeaderRow(headerLabel="HDF5 %s" % objectType)
+ self.__data.addHeaderRow(headerLabel="Path info")
+
+ self.__data.addHeaderValueRow("basename", lambda x: os.path.basename(x.name))
+ self.__data.addHeaderValueRow("name", lambda x: x.name)
+ if silx.io.is_file(obj):
+ self.__data.addHeaderValueRow("filename", lambda x: x.filename)
+
+ if isinstance(obj, silx.gui.hdf5.H5Node):
+ # helpful informations if the object come from an HDF5 tree
+ self.__data.addHeaderValueRow("local_basename", lambda x: x.local_basename)
+ self.__data.addHeaderValueRow("local_name", lambda x: x.local_name)
+ self.__data.addHeaderValueRow("local_filename", lambda x: x.local_file.filename)
+
+ if hasattr(obj, "dtype"):
+ self.__data.addHeaderRow(headerLabel="Data info")
+ self.__data.addHeaderValueRow("dtype", lambda x: x.dtype)
+ if hasattr(obj, "shape"):
+ self.__data.addHeaderValueRow("shape", lambda x: x.shape)
+ if hasattr(obj, "size"):
+ self.__data.addHeaderValueRow("size", lambda x: x.size)
+ if hasattr(obj, "chunks") and obj.chunks is not None:
+ self.__data.addHeaderValueRow("chunks", lambda x: x.chunks)
+
+ # relative to compression
+ # h5py expose compression, compression_opts but are not initialized
+ # for external plugins, then we use id
+ # h5py also expose fletcher32 and shuffle attributes, but it is also
+ # part of the filters
+ if hasattr(obj, "shape") and hasattr(obj, "id"):
+ dcpl = obj.id.get_create_plist()
+ if dcpl.get_nfilters() > 0:
+ self.__data.addHeaderRow(headerLabel="Compression info")
+ pos = _CellData(value="Position", isHeader=True)
+ hdf5id = _CellData(value="HDF5 ID", isHeader=True)
+ name = _CellData(value="Name", isHeader=True)
+ options = _CellData(value="Options", isHeader=True)
+ self.__data.addRow(pos, hdf5id, name, options)
+ for index in range(dcpl.get_nfilters()):
+ callback = lambda index, dataIndex, x: self.__get_filter_info(x, index)[dataIndex]
+ pos = _CellData(value=functools.partial(callback, index, 0))
+ hdf5id = _CellData(value=functools.partial(callback, index, 1))
+ name = _CellData(value=functools.partial(callback, index, 2))
+ options = _CellData(value=functools.partial(callback, index, 3))
+ self.__data.addRow(pos, hdf5id, name, options)
+
+ if hasattr(obj, "attrs"):
+ if len(obj.attrs) > 0:
+ self.__data.addHeaderRow(headerLabel="Attributes")
+ for key in sorted(obj.attrs.keys()):
+ callback = lambda key, x: self.__formatter.toString(x.attrs[key])
+ self.__data.addHeaderValueRow(headerLabel=key, value=functools.partial(callback, key))
+
+ def __get_filter_info(self, dataset, filterIndex):
+ """Get a tuple of readable info from dataset filters
+
+ :param h5py.Dataset dataset: A h5py dataset
+ :param int filterId:
+ """
+ try:
+ dcpl = dataset.id.get_create_plist()
+ info = dcpl.get_filter(filterIndex)
+ filterId, _flags, cdValues, name = info
+ name = self.__formatter.toString(name)
+ options = " ".join([self.__formatter.toString(i) for i in cdValues])
+ return (filterIndex, filterId, name, options)
+ except Exception:
+ _logger.debug("Backtrace", exc_info=True)
+ return [filterIndex, None, None, None]
+
+ def object(self):
+ """Returns the internal object modelized.
+
+ :rtype: An h5py-like object
+ """
+ return self.__obj
+
+ def setFormatter(self, formatter):
+ """Set the formatter object to be used to display data from the model
+
+ :param TextFormatter formatter: Formatter to use
+ """
+ if formatter is self.__formatter:
+ return
+
+ if qt.qVersion() > "4.6":
+ self.beginResetModel()
+
+ if self.__formatter is not None:
+ self.__formatter.formatChanged.disconnect(self.__formatChanged)
+
+ self.__formatter = formatter
+ if self.__formatter is not None:
+ self.__formatter.formatChanged.connect(self.__formatChanged)
+
+ if qt.qVersion() > "4.6":
+ self.endResetModel()
+ else:
+ self.reset()
+
+ def getFormatter(self):
+ """Returns the text formatter used.
+
+ :rtype: TextFormatter
+ """
+ return self.__formatter
+
+ def __formatChanged(self):
+ """Called when the format changed.
+ """
+ self.reset()
+
+
+class Hdf5TableView(HierarchicalTableView.HierarchicalTableView):
+ """A widget to display metadata about a HDF5 node using a table."""
+
+ def __init__(self, parent=None):
+ super(Hdf5TableView, self).__init__(parent)
+ self.setModel(Hdf5TableModel(self))
+
+ def isSupportedData(self, data):
+ """
+ Returns true if the provided object can be modelized using this model.
+ """
+ return self.model().isSupportedObject(data)
+
+ def setData(self, data):
+ """Set the h5py-like object exposed by the model
+
+ :param h5pyObject: A h5py-like object. It can be a `h5py.Dataset`,
+ a `h5py.File`, a `h5py.Group`. It also can be a,
+ `silx.gui.hdf5.H5Node` which is needed to display some local path
+ information.
+ """
+ self.model().setObject(data)
+ header = self.horizontalHeader()
+ if qt.qVersion() < "5.0":
+ setResizeMode = header.setResizeMode
+ else:
+ setResizeMode = header.setSectionResizeMode
+ setResizeMode(0, qt.QHeaderView.Fixed)
+ setResizeMode(1, qt.QHeaderView.Stretch)
+ header.setStretchLastSection(True)
diff --git a/silx/gui/data/NXdataWidgets.py b/silx/gui/data/NXdataWidgets.py
new file mode 100644
index 0000000..343c7f9
--- /dev/null
+++ b/silx/gui/data/NXdataWidgets.py
@@ -0,0 +1,523 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines widgets used by _NXdataView.
+"""
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "20/03/2017"
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
+from silx.gui.plot import Plot1D, Plot2D, StackView
+
+from silx.math.calibration import ArrayCalibration, NoCalibration, LinearCalibration
+
+
+class ArrayCurvePlot(qt.QWidget):
+ """
+ Widget for plotting a curve from a multi-dimensional signal array
+ and a 1D axis array.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last dimension must have the same length as
+ the axis array.
+
+ The widget provides sliders to select indices on the first (n - 1)
+ dimensions of the signal array, and buttons to add/replace selected
+ curves to the plot.
+
+ This widget also handles simple 2D or 3D scatter plots (third dimension
+ displayed as colour of points).
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayCurvePlot, self).__init__(parent)
+
+ self.__signal = None
+ self.__signal_name = None
+ self.__signal_errors = None
+ self.__axis = None
+ self.__axis_name = None
+ self.__axis_errors = None
+ self.__values = None
+
+ self.__first_curve_added = False
+
+ self._plot = Plot1D(self)
+ self._plot.setDefaultColormap( # for scatters
+ {"name": "viridis",
+ "vmin": 0., "vmax": 1., # ignored (autoscale) but mandatory
+ "normalization": "linear",
+ "autoscale": True})
+
+ self.selectorDock = qt.QDockWidget("Data selector", self._plot)
+ # not closable
+ self.selectorDock.setFeatures(qt.QDockWidget.DockWidgetMovable |
+ qt.QDockWidget.DockWidgetFloatable)
+ self._selector = NumpyAxesSelector(self.selectorDock)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self.__selector_is_connected = False
+ self.selectorDock.setWidget(self._selector)
+ self._plot.addTabbedDockWidget(self.selectorDock)
+
+ layout = qt.QGridLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._plot, 0, 0)
+
+ self.setLayout(layout)
+
+ def setCurveData(self, y, x=None, values=None,
+ yerror=None, xerror=None,
+ ylabel=None, xlabel=None, title=None):
+ """
+
+ :param y: dataset to be represented by the y (vertical) axis.
+ For a scatter, this must be a 1D array and x and values must be
+ 1-D arrays of the same size.
+ In other cases, it can be a n-D array whose last dimension must
+ have the same length as x (and values must be None)
+ :param x: 1-D dataset used as the curve's x values. If provided,
+ its lengths must be equal to the length of the last dimension of
+ ``y`` (and equal to the length of ``value``, for a scatter plot).
+ :param values: Values, to be provided for a x-y-value scatter plot.
+ This will be used to compute the color map and assign colors
+ to the points.
+ :param yerror: 1-D dataset of errors for y, or None
+ :param xerror: 1-D dataset of errors for x, or None
+ :param ylabel: Label for Y axis
+ :param xlabel: Label for X axis
+ :param title: Graph title
+ """
+ self.__signal = y
+ self.__signal_name = ylabel
+ self.__signal_errors = yerror
+ self.__axis = x
+ self.__axis_name = xlabel
+ self.__axis_errors = xerror
+ self.__values = values
+
+ if self.__selector_is_connected:
+ self._selector.selectionChanged.disconnect(self._updateCurve)
+ self.__selector_is_connected = False
+ self._selector.setData(y)
+ self._selector.setAxisNames([ylabel or "Y"])
+
+ if len(y.shape) < 2:
+ self.selectorDock.hide()
+ else:
+ self.selectorDock.show()
+
+ self._plot.setGraphTitle(title or "")
+ self._plot.setGraphXLabel(self.__axis_name or "X")
+ self._plot.setGraphYLabel(self.__signal_name or "Y")
+ self._updateCurve()
+
+ if not self.__selector_is_connected:
+ self._selector.selectionChanged.connect(self._updateCurve)
+ self.__selector_is_connected = True
+
+ def _updateCurve(self):
+ y = self._selector.selectedData()
+ x = self.__axis
+ if x is None:
+ x = numpy.arange(len(y))
+ elif numpy.isscalar(x) or len(x) == 1:
+ # constant axis
+ x = x * numpy.ones_like(y)
+ elif len(x) == 2 and len(y) != 2:
+ # linear calibration a + b * x
+ x = x[0] + x[1] * numpy.arange(len(y))
+ legend = self.__signal_name + "["
+ for sl in self._selector.selection():
+ if sl == slice(None):
+ legend += ":, "
+ else:
+ legend += str(sl) + ", "
+ legend = legend[:-2] + "]"
+ if self.__signal_errors is not None:
+ y_errors = self.__signal_errors[self._selector.selection()]
+ else:
+ y_errors = None
+
+ self._plot.remove(kind=("curve", "scatter"))
+
+ # values: x-y-v scatter
+ if self.__values is not None:
+ self._plot.addScatter(x, y, self.__values,
+ legend=legend,
+ xerror=self.__axis_errors,
+ yerror=y_errors)
+
+ # x monotonically increasing: curve
+ elif numpy.all(numpy.diff(x) > 0):
+ self._plot.addCurve(x, y, legend=legend,
+ xerror=self.__axis_errors,
+ yerror=y_errors)
+
+ # scatter
+ else:
+ self._plot.addScatter(x, y, value=numpy.ones_like(y),
+ legend=legend,
+ xerror=self.__axis_errors,
+ yerror=y_errors)
+ self._plot.resetZoom()
+ self._plot.setGraphXLabel(self.__axis_name)
+ self._plot.setGraphYLabel(self.__signal_name)
+
+ def clear(self):
+ self._plot.clear()
+
+
+class ArrayImagePlot(qt.QWidget):
+ """
+ Widget for plotting an image from a multi-dimensional signal array
+ and two 1D axes array.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last two dimensions must have the same length as
+ the axes arrays.
+
+ Sliders are provided to select indices on the first (n - 2) dimensions of
+ the signal array, and the plot is updated to show the image corresponding
+ to the selection.
+
+ If one or both of the axes does not have regularly spaced values, the
+ the image is plotted as a coloured scatter plot.
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayImagePlot, self).__init__(parent)
+
+ self.__signal = None
+ self.__signal_name = None
+ self.__x_axis = None
+ self.__x_axis_name = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+
+ self._plot = Plot2D(self)
+ self._plot.setDefaultColormap(
+ {"name": "viridis",
+ "vmin": 0., "vmax": 1., # ignored (autoscale) but mandatory
+ "normalization": "linear",
+ "autoscale": True})
+
+ self.selectorDock = qt.QDockWidget("Data selector", self._plot)
+ # not closable
+ self.selectorDock.setFeatures(qt.QDockWidget.DockWidgetMovable |
+ qt.QDockWidget.DockWidgetFloatable)
+ self._legend = qt.QLabel(self)
+ self._selector = NumpyAxesSelector(self.selectorDock)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self.__selector_is_connected = False
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(self._plot)
+ layout.addWidget(self._legend)
+ self.selectorDock.setWidget(self._selector)
+ self._plot.addTabbedDockWidget(self.selectorDock)
+
+ self.setLayout(layout)
+
+ def setImageData(self, signal,
+ x_axis=None, y_axis=None,
+ signal_name=None,
+ xlabel=None, ylabel=None,
+ title=None):
+ """
+
+ :param signal: n-D dataset, whose last 2 dimensions are used as the
+ image's values.
+ :param x_axis: 1-D dataset used as the image's x coordinates. If
+ provided, its lengths must be equal to the length of the last
+ dimension of ``signal``.
+ :param y_axis: 1-D dataset used as the image's y. If provided,
+ its lengths must be equal to the length of the 2nd to last
+ dimension of ``signal``.
+ :param signal_name: Label used in the legend
+ :param xlabel: Label for X axis
+ :param ylabel: Label for Y axis
+ :param title: Graph title
+ """
+ if self.__selector_is_connected:
+ self._selector.selectionChanged.disconnect(self._updateImage)
+ self.__selector_is_connected = False
+
+ self.__signal = signal
+ self.__signal_name = signal_name or ""
+ self.__x_axis = x_axis
+ self.__x_axis_name = xlabel
+ self.__y_axis = y_axis
+ self.__y_axis_name = ylabel
+
+ self._selector.setData(signal)
+ self._selector.setAxisNames([ylabel or "Y", xlabel or "X"])
+
+ if len(signal.shape) < 3:
+ self.selectorDock.hide()
+ else:
+ self.selectorDock.show()
+
+ self._plot.setGraphTitle(title or "")
+ self._plot.setGraphXLabel(self.__x_axis_name or "X")
+ self._plot.setGraphYLabel(self.__y_axis_name or "Y")
+
+ self._updateImage()
+
+ if not self.__selector_is_connected:
+ self._selector.selectionChanged.connect(self._updateImage)
+ self.__selector_is_connected = True
+
+ def _updateImage(self):
+ legend = self.__signal_name + "["
+ for sl in self._selector.selection():
+ if sl == slice(None):
+ legend += ":, "
+ else:
+ legend += str(sl) + ", "
+ legend = legend[:-2] + "]"
+ self._legend.setText("Displayed data: " + legend)
+
+ img = self._selector.selectedData()
+ x_axis = self.__x_axis
+ y_axis = self.__y_axis
+
+ if x_axis is None and y_axis is None:
+ xcalib = NoCalibration()
+ ycalib = NoCalibration()
+ else:
+ if x_axis is None:
+ # no calibration
+ x_axis = numpy.arange(img.shape[-1])
+ elif numpy.isscalar(x_axis) or len(x_axis) == 1:
+ # constant axis
+ x_axis = x_axis * numpy.ones((img.shape[-1], ))
+ elif len(x_axis) == 2:
+ # linear calibration
+ x_axis = x_axis[0] * numpy.arange(img.shape[-1]) + x_axis[1]
+
+ if y_axis is None:
+ y_axis = numpy.arange(img.shape[-2])
+ elif numpy.isscalar(y_axis) or len(y_axis) == 1:
+ y_axis = y_axis * numpy.ones((img.shape[-2], ))
+ elif len(y_axis) == 2:
+ y_axis = y_axis[0] * numpy.arange(img.shape[-2]) + y_axis[1]
+
+ xcalib = ArrayCalibration(x_axis)
+ ycalib = ArrayCalibration(y_axis)
+
+ self._plot.remove(kind=("scatter", "image"))
+ if xcalib.is_affine() and ycalib.is_affine():
+ # regular image
+ xorigin, xscale = xcalib(0), xcalib.get_slope()
+ yorigin, yscale = ycalib(0), ycalib.get_slope()
+ origin = (xorigin, yorigin)
+ scale = (xscale, yscale)
+
+ self._plot.addImage(img, legend=legend,
+ origin=origin, scale=scale)
+ else:
+ scatterx, scattery = numpy.meshgrid(x_axis, y_axis)
+ self._plot.addScatter(numpy.ravel(scatterx),
+ numpy.ravel(scattery),
+ numpy.ravel(img),
+ legend=legend)
+ self._plot.setGraphXLabel(self.__x_axis_name)
+ self._plot.setGraphYLabel(self.__y_axis_name)
+ self._plot.resetZoom()
+
+ def clear(self):
+ self._plot.clear()
+
+
+class ArrayStackPlot(qt.QWidget):
+ """
+ Widget for plotting a n-D array (n >= 3) as a stack of images.
+ Three axis arrays can be provided to calibrate the axes.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last 3 dimensions must have the same length as
+ the axes arrays.
+
+ Sliders are provided to select indices on the first (n - 3) dimensions of
+ the signal array, and the plot is updated to load the stack corresponding
+ to the selection.
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayStackPlot, self).__init__(parent)
+
+ self.__signal = None
+ self.__signal_name = None
+ # the Z, Y, X axes apply to the last three dimensions of the signal
+ # (in that order)
+ self.__z_axis = None
+ self.__z_axis_name = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+ self.__x_axis = None
+ self.__x_axis_name = None
+
+ self._stack_view = StackView(self)
+ self._hline = qt.QFrame(self)
+ self._hline.setFrameStyle(qt.QFrame.HLine)
+ self._hline.setFrameShadow(qt.QFrame.Sunken)
+ self._legend = qt.QLabel(self)
+ self._selector = NumpyAxesSelector(self)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self.__selector_is_connected = False
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(self._stack_view)
+ layout.addWidget(self._hline)
+ layout.addWidget(self._legend)
+ layout.addWidget(self._selector)
+
+ self.setLayout(layout)
+
+ def setStackData(self, signal,
+ x_axis=None, y_axis=None, z_axis=None,
+ signal_name=None,
+ xlabel=None, ylabel=None, zlabel=None,
+ title=None):
+ """
+
+ :param signal: n-D dataset, whose last 3 dimensions are used as the
+ 3D stack values.
+ :param x_axis: 1-D dataset used as the image's x coordinates. If
+ provided, its lengths must be equal to the length of the last
+ dimension of ``signal``.
+ :param y_axis: 1-D dataset used as the image's y. If provided,
+ its lengths must be equal to the length of the 2nd to last
+ dimension of ``signal``.
+ :param z_axis: 1-D dataset used as the image's z. If provided,
+ its lengths must be equal to the length of the 3rd to last
+ dimension of ``signal``.
+ :param signal_name: Label used in the legend
+ :param xlabel: Label for X axis
+ :param ylabel: Label for Y axis
+ :param zlabel: Label for Z axis
+ :param title: Graph title
+ """
+ if self.__selector_is_connected:
+ self._selector.selectionChanged.disconnect(self._updateStack)
+ self.__selector_is_connected = False
+
+ self.__signal = signal
+ self.__signal_name = signal_name or ""
+ self.__x_axis = x_axis
+ self.__x_axis_name = xlabel
+ self.__y_axis = y_axis
+ self.__y_axis_name = ylabel
+ self.__z_axis = z_axis
+ self.__z_axis_name = zlabel
+
+ self._selector.setData(signal)
+ self._selector.setAxisNames([ylabel or "Y", xlabel or "X", zlabel or "Z"])
+
+ self._stack_view.setGraphTitle(title or "")
+ # by default, the z axis is the image position (dimension not plotted)
+ self._stack_view.setGraphXLabel(self.__x_axis_name or "X")
+ self._stack_view.setGraphYLabel(self.__y_axis_name or "Y")
+
+ self._updateStack()
+
+ ndims = len(signal.shape)
+ self._stack_view.setFirstStackDimension(ndims - 3)
+
+ # the legend label shows the selection slice producing the volume
+ # (only interesting for ndim > 3)
+ if ndims > 3:
+ self._selector.setVisible(True)
+ self._legend.setVisible(True)
+ self._hline.setVisible(True)
+ else:
+ self._selector.setVisible(False)
+ self._legend.setVisible(False)
+ self._hline.setVisible(False)
+
+ if not self.__selector_is_connected:
+ self._selector.selectionChanged.connect(self._updateStack)
+ self.__selector_is_connected = True
+
+ @staticmethod
+ def _get_origin_scale(axis):
+ """Assuming axis is a regularly spaced 1D array,
+ return a tuple (origin, scale) where:
+ - origin = axis[0]
+ - scale = (axis[n-1] - axis[0]) / (n -1)
+ :param axis: 1D numpy array
+ :return: Tuple (axis[0], (axis[-1] - axis[0]) / (len(axis) - 1))
+ """
+ return axis[0], (axis[-1] - axis[0]) / (len(axis) - 1)
+
+ def _updateStack(self):
+ """Update displayed stack according to the current axes selector
+ data."""
+ stk = self._selector.selectedData()
+ x_axis = self.__x_axis
+ y_axis = self.__y_axis
+ z_axis = self.__z_axis
+
+ calibrations = []
+ for axis in [z_axis, y_axis, x_axis]:
+
+ if axis is None:
+ calibrations.append(NoCalibration())
+ elif len(axis) == 2:
+ calibrations.append(
+ LinearCalibration(y_intercept=axis[0],
+ slope=axis[1]))
+ else:
+ calibrations.append(ArrayCalibration(axis))
+
+ legend = self.__signal_name + "["
+ for sl in self._selector.selection():
+ if sl == slice(None):
+ legend += ":, "
+ else:
+ legend += str(sl) + ", "
+ legend = legend[:-2] + "]"
+ self._legend.setText("Displayed data: " + legend)
+
+ self._stack_view.setStack(stk, calibrations=calibrations)
+ self._stack_view.setLabels(
+ labels=[self.__z_axis_name,
+ self.__y_axis_name,
+ self.__x_axis_name])
+
+ def clear(self):
+ self._stack_view.clear()
diff --git a/silx/gui/data/NumpyAxesSelector.py b/silx/gui/data/NumpyAxesSelector.py
new file mode 100644
index 0000000..f4641da
--- /dev/null
+++ b/silx/gui/data/NumpyAxesSelector.py
@@ -0,0 +1,468 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines a widget able to convert a numpy array from n-dimensions
+to a numpy array with less dimensions.
+"""
+from __future__ import division
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+import numpy
+import functools
+from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
+from silx.gui import qt
+import silx.utils.weakref
+
+
+class _Axis(qt.QWidget):
+ """Widget displaying an axis.
+
+ It allows to display and scroll in the axis, and provide a widget to
+ map the axis with a named axis (the one from the view).
+ """
+
+ valueChanged = qt.Signal(int)
+ """Emitted when the location on the axis change."""
+
+ axisNameChanged = qt.Signal(object)
+ """Emitted when the user change the name of the axis."""
+
+ def __init__(self, parent=None):
+ """Constructor
+
+ :param parent: Parent of the widget
+ """
+ super(_Axis, self).__init__(parent)
+ self.__axisNumber = None
+ self.__customAxisNames = set([])
+ self.__label = qt.QLabel(self)
+ self.__axes = qt.QComboBox(self)
+ self.__axes.currentIndexChanged[int].connect(self.__axisMappingChanged)
+ self.__slider = HorizontalSliderWithBrowser(self)
+ self.__slider.valueChanged[int].connect(self.__sliderValueChanged)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self.__label)
+ layout.addWidget(self.__axes)
+ layout.addWidget(self.__slider, 10000)
+ layout.addStretch(1)
+ self.setLayout(layout)
+
+ def slider(self):
+ """Returns the slider used to display axes location.
+
+ :rtype: HorizontalSliderWithBrowser
+ """
+ return self.__slider
+
+ def setAxis(self, number, position, size):
+ """Set axis information.
+
+ :param int number: The number of the axis (from the original numpy
+ array)
+ :param int position: The current position in the axis (for a slicing)
+ :param int size: The size of this axis (0..n)
+ """
+ self.__label.setText("Dimension %s" % number)
+ self.__axisNumber = number
+ self.__slider.setMaximum(size - 1)
+
+ def axisNumber(self):
+ """Returns the axis number.
+
+ :rtype: int
+ """
+ return self.__axisNumber
+
+ def setAxisName(self, axisName):
+ """Set the current used axis name.
+
+ If this name is not available an exception is raised. An empty string
+ means that no name is selected.
+
+ :param str axisName: The new name of the axis
+ :raise ValueError: When the name is not available
+ """
+ if axisName == "" and self.__axes.count() == 0:
+ self.__axes.setCurrentIndex(-1)
+ self.__updateSliderVisibility()
+ for index in range(self.__axes.count()):
+ name = self.__axes.itemData(index)
+ if name == axisName:
+ self.__axes.setCurrentIndex(index)
+ self.__updateSliderVisibility()
+ return
+ raise ValueError("Axis name '%s' not found", axisName)
+
+ def axisName(self):
+ """Returns the selected axis name.
+
+ If no names are selected, an empty string is retruned.
+
+ :rtype: str
+ """
+ index = self.__axes.currentIndex()
+ if index == -1:
+ return ""
+ return self.__axes.itemData(index)
+
+ def setAxisNames(self, axesNames):
+ """Set the available list of names for the axis.
+
+ :param list[str] axesNames: List of available names
+ """
+ self.__axes.clear()
+ previous = self.__axes.blockSignals(True)
+ self.__axes.addItem(" ", "")
+ for axis in axesNames:
+ self.__axes.addItem(axis, axis)
+ self.__axes.blockSignals(previous)
+ self.__updateSliderVisibility()
+
+ def setCustomAxis(self, axesNames):
+ """Set the available list of named axis which can be set to a value.
+
+ :param list[str] axesNames: List of customable axis names
+ """
+ self.__customAxisNames = set(axesNames)
+ self.__updateSliderVisibility()
+
+ def __axisMappingChanged(self, index):
+ """Called when the selected name change.
+
+ :param int index: Selected index
+ """
+ self.__updateSliderVisibility()
+ name = self.axisName()
+ self.axisNameChanged.emit(name)
+
+ def __updateSliderVisibility(self):
+ """Update the visibility of the slider according to axis names and
+ customable axis names."""
+ name = self.axisName()
+ isVisible = name == "" or name in self.__customAxisNames
+ self.__slider.setVisible(isVisible)
+
+ def value(self):
+ """Returns the current selected position in the axis.
+
+ :rtype: int
+ """
+ return self.__slider.value()
+
+ def __sliderValueChanged(self, value):
+ """Called when the selected position in the axis change.
+
+ :param int value: Position of the axis
+ """
+ self.valueChanged.emit(value)
+
+ def setNamedAxisSelectorVisibility(self, visible):
+ """Hide or show the named axis combobox.
+ If both the selector and the slider are hidden,
+ hide the entire widget.
+
+ :param visible: boolean
+ """
+ self.__axes.setVisible(visible)
+ name = self.axisName()
+
+ if not visible and name != "":
+ self.setVisible(False)
+ else:
+ self.setVisible(True)
+
+
+class NumpyAxesSelector(qt.QWidget):
+ """Widget to select a view from a numpy array.
+
+ .. image:: img/NumpyAxesSelector.png
+
+ The widget is set with an input data using :meth:`setData`, and a requested
+ output dimension using :meth:`setAxisNames`.
+
+ Widgets are provided to selected expected input axis, and a slice on the
+ non-selected axis.
+
+ The final selected array can be reached using the getter
+ :meth:`selectedData`, and the event `selectionChanged`.
+
+ If the input data is a HDF5 Dataset, the selected output data will be a
+ new numpy array.
+ """
+
+ dataChanged = qt.Signal()
+ """Emitted when the input data change"""
+
+ selectedAxisChanged = qt.Signal()
+ """Emitted when the selected axis change"""
+
+ selectionChanged = qt.Signal()
+ """Emitted when the selected data change"""
+
+ customAxisChanged = qt.Signal(str, int)
+ """Emitted when a custom axis change"""
+
+ def __init__(self, parent=None):
+ """Constructor
+
+ :param parent: Parent of the widget
+ """
+ super(NumpyAxesSelector, self).__init__(parent)
+
+ self.__data = None
+ self.__selectedData = None
+ self.__selection = tuple()
+ self.__axis = []
+ self.__axisNames = []
+ self.__customAxisNames = set([])
+ self.__namedAxesVisibility = True
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSizeConstraint(qt.QLayout.SetMinAndMaxSize)
+ self.setLayout(layout)
+
+ def clear(self):
+ """Clear the widget."""
+ self.setData(None)
+
+ def setAxisNames(self, axesNames):
+ """Set the axis names of the output selected data.
+
+ Axis names are defined from slower to faster axis.
+
+ The size of the list will constrain the dimension of the resulting
+ array.
+
+ :param list[str] axesNames: List of string identifying axis names
+ """
+ self.__axisNames = list(axesNames)
+ delta = len(self.__axis) - len(self.__axisNames)
+ if delta < 0:
+ delta = 0
+ for index, axis in enumerate(self.__axis):
+ previous = axis.blockSignals(True)
+ axis.setAxisNames(self.__axisNames)
+ if index >= delta and index - delta < len(self.__axisNames):
+ axis.setAxisName(self.__axisNames[index - delta])
+ else:
+ axis.setAxisName("")
+ axis.blockSignals(previous)
+ self.__updateSelectedData()
+
+ def setCustomAxis(self, axesNames):
+ """Set the available list of named axis which can be set to a value.
+
+ :param list[str] axesNames: List of customable axis names
+ """
+ self.__customAxisNames = set(axesNames)
+ for axis in self.__axis:
+ axis.setCustomAxis(self.__customAxisNames)
+
+ def setData(self, data):
+ """Set the input data unsed by the widget.
+
+ :param numpy.ndarray data: The input data
+ """
+ if self.__data is not None:
+ # clean up
+ for widget in self.__axis:
+ self.layout().removeWidget(widget)
+ widget.deleteLater()
+ self.__axis = []
+
+ self.__data = data
+
+ if data is not None:
+ # create expected axes
+ dimensionNumber = len(data.shape)
+ delta = dimensionNumber - len(self.__axisNames)
+ for index in range(dimensionNumber):
+ axis = _Axis(self)
+ axis.setAxis(index, 0, data.shape[index])
+ axis.setAxisNames(self.__axisNames)
+ axis.setCustomAxis(self.__customAxisNames)
+ if index >= delta and index - delta < len(self.__axisNames):
+ axis.setAxisName(self.__axisNames[index - delta])
+ # this weak method was expected to be able to delete sub widget
+ callback = functools.partial(silx.utils.weakref.WeakMethodProxy(self.__axisValueChanged), axis)
+ axis.valueChanged.connect(callback)
+ # this weak method was expected to be able to delete sub widget
+ callback = functools.partial(silx.utils.weakref.WeakMethodProxy(self.__axisNameChanged), axis)
+ axis.axisNameChanged.connect(callback)
+ axis.setNamedAxisSelectorVisibility(self.__namedAxesVisibility)
+ self.layout().addWidget(axis)
+ self.__axis.append(axis)
+ self.__normalizeAxisGeometry()
+
+ self.dataChanged.emit()
+ self.__updateSelectedData()
+
+ def __normalizeAxisGeometry(self):
+ """Update axes geometry to align all axes components together."""
+ if len(self.__axis) <= 0:
+ return
+ lineEditWidth = max([a.slider().lineEdit().minimumSize().width() for a in self.__axis])
+ limitWidth = max([a.slider().limitWidget().minimumSizeHint().width() for a in self.__axis])
+ for a in self.__axis:
+ a.slider().lineEdit().setFixedWidth(lineEditWidth)
+ a.slider().limitWidget().setFixedWidth(limitWidth)
+
+ def __axisValueChanged(self, axis, value):
+ name = axis.axisName()
+ if name in self.__customAxisNames:
+ self.customAxisChanged.emit(name, value)
+ else:
+ self.__updateSelectedData()
+
+ def __axisNameChanged(self, axis, name):
+ """Called when an axis name change.
+
+ :param _Axis axis: The changed axis
+ :param str name: The new name of the axis
+ """
+ names = [x.axisName() for x in self.__axis]
+ missingName = set(self.__axisNames) - set(names) - set("")
+ if len(missingName) == 0:
+ missingName = None
+ elif len(missingName) == 1:
+ missingName = list(missingName)[0]
+ else:
+ raise Exception("Unexpected state")
+
+ axisChanged = True
+
+ if axis.axisName() == "":
+ # set the removed label to another widget if it is possible
+ availableWidget = None
+ for widget in self.__axis:
+ if widget is axis:
+ continue
+ if widget.axisName() == "":
+ availableWidget = widget
+ break
+ if availableWidget is None:
+ # If there is no other solution we set the name at the same place
+ axisChanged = False
+ availableWidget = axis
+ previous = availableWidget.blockSignals(True)
+ availableWidget.setAxisName(missingName)
+ availableWidget.blockSignals(previous)
+ else:
+ # there is a duplicated name somewhere
+ # we swap it with the missing name or with nothing
+ dupWidget = None
+ for widget in self.__axis:
+ if widget is axis:
+ continue
+ if widget.axisName() == axis.axisName():
+ dupWidget = widget
+ break
+ if missingName is None:
+ missingName = ""
+ previous = dupWidget.blockSignals(True)
+ dupWidget.setAxisName(missingName)
+ dupWidget.blockSignals(previous)
+
+ if self.__data is None:
+ return
+ if axisChanged:
+ self.selectedAxisChanged.emit()
+ self.__updateSelectedData()
+
+ def __updateSelectedData(self):
+ """Update the selected data according to the state of the widget.
+
+ It fires a `selectionChanged` event.
+ """
+ if self.__data is None:
+ if self.__selectedData is not None:
+ self.__selectedData = None
+ self.__selection = tuple()
+ self.selectionChanged.emit()
+ return
+
+ selection = []
+ axisNames = []
+ for slider in self.__axis:
+ name = slider.axisName()
+ if name == "":
+ selection.append(slider.value())
+ else:
+ selection.append(slice(None))
+ axisNames.append(name)
+
+ self.__selection = tuple(selection)
+ # get a view with few fixed dimensions
+ # with a h5py dataset, it create a copy
+ # TODO we can reuse the same memory in case of a copy
+ view = self.__data[self.__selection]
+
+ # order axis as expected
+ source = []
+ destination = []
+ order = []
+ for index, name in enumerate(self.__axisNames):
+ destination.append(index)
+ source.append(axisNames.index(name))
+ for _, s in sorted(zip(destination, source)):
+ order.append(s)
+ view = numpy.transpose(view, order)
+
+ self.__selectedData = view
+ self.selectionChanged.emit()
+
+ def data(self):
+ """Returns the input data.
+
+ :rtype: numpy.ndarray
+ """
+ return self.__data
+
+ def selectedData(self):
+ """Returns the output data.
+
+ :rtype: numpy.ndarray
+ """
+ return self.__selectedData
+
+ def selection(self):
+ """Returns the selection tuple used to slice the data.
+
+ :rtype: tuple
+ """
+ return self.__selection
+
+ def setNamedAxesSelectorVisibility(self, visible):
+ """Show or hide the combo-boxes allowing to map the plot axes
+ to the data dimension.
+
+ :param visible: Boolean
+ """
+ self.__namedAxesVisibility = visible
+ for axis in self.__axis:
+ axis.setNamedAxisSelectorVisibility(visible)
diff --git a/silx/gui/data/RecordTableView.py b/silx/gui/data/RecordTableView.py
new file mode 100644
index 0000000..ce6a178
--- /dev/null
+++ b/silx/gui/data/RecordTableView.py
@@ -0,0 +1,405 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module define model and widget to display 1D slices from numpy
+array using compound data types or hdf5 databases.
+"""
+from __future__ import division
+
+import itertools
+import numpy
+from silx.gui import qt
+import silx.io
+from .TextFormatter import TextFormatter
+from silx.gui.widgets.TableWidget import CopySelectedCellsAction
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "27/01/2017"
+
+
+class _MultiLineItem(qt.QItemDelegate):
+ """Draw a multiline text without hiding anything.
+
+ The paint method display a cell without any wrap. And an editor is
+ available to scroll into the selected cell.
+ """
+
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QWidget parent: Parent of the widget
+ """
+ qt.QItemDelegate.__init__(self, parent)
+ self.__textOptions = qt.QTextOption()
+ self.__textOptions.setFlags(qt.QTextOption.IncludeTrailingSpaces |
+ qt.QTextOption.ShowTabsAndSpaces)
+ self.__textOptions.setWrapMode(qt.QTextOption.NoWrap)
+ self.__textOptions.setAlignment(qt.Qt.AlignTop | qt.Qt.AlignLeft)
+
+ def paint(self, painter, option, index):
+ """
+ Write multiline text without using any wrap or any alignment according
+ to the cell size.
+
+ :param qt.QPainter painter: Painter context used to displayed the cell
+ :param qt.QStyleOptionViewItem option: Control how the editor is shown
+ :param qt.QIndex index: Index of the data to display
+ """
+ painter.save()
+
+ # set colors
+ painter.setPen(qt.QPen(qt.Qt.NoPen))
+ if option.state & qt.QStyle.State_Selected:
+ brush = option.palette.highlight()
+ painter.setBrush(brush)
+ else:
+ brush = index.data(qt.Qt.BackgroundRole)
+ if brush is None:
+ # default background color for a cell
+ brush = qt.Qt.white
+ painter.setBrush(brush)
+ painter.drawRect(option.rect)
+
+ if index.isValid():
+ if option.state & qt.QStyle.State_Selected:
+ brush = option.palette.highlightedText()
+ else:
+ brush = index.data(qt.Qt.ForegroundRole)
+ if brush is None:
+ brush = option.palette.text()
+ painter.setPen(qt.QPen(brush.color()))
+ text = index.data(qt.Qt.DisplayRole)
+ painter.drawText(qt.QRectF(option.rect), text, self.__textOptions)
+
+ painter.restore()
+
+ def createEditor(self, parent, option, index):
+ """
+ Returns the widget used to edit the item specified by index for editing.
+
+ We use it not to edit the content but to show the content with a
+ convenient scroll bar.
+
+ :param qt.QWidget parent: Parent of the widget
+ :param qt.QStyleOptionViewItem option: Control how the editor is shown
+ :param qt.QIndex index: Index of the data to display
+ """
+ if not index.isValid():
+ return super(_MultiLineItem, self).createEditor(parent, option, index)
+
+ editor = qt.QTextEdit(parent)
+ editor.setReadOnly(True)
+ return editor
+
+ def setEditorData(self, editor, index):
+ """
+ Read data from the model and feed the editor.
+
+ :param qt.QWidget editor: Editor widget
+ :param qt.QIndex index: Index of the data to display
+ """
+ text = index.model().data(index, qt.Qt.EditRole)
+ editor.setText(text)
+
+ def updateEditorGeometry(self, editor, option, index):
+ """
+ Update the geometry of the editor according to the changes of the view.
+
+ :param qt.QWidget editor: Editor widget
+ :param qt.QStyleOptionViewItem option: Control how the editor is shown
+ :param qt.QIndex index: Index of the data to display
+ """
+ editor.setGeometry(option.rect)
+
+
+class RecordTableModel(qt.QAbstractTableModel):
+ """This data model provides access to 1D slices from numpy array using
+ compound data types or hdf5 databases.
+
+ Each entries are displayed in a single row, and each columns contain a
+ specific field of the compound type.
+
+ It also allows to display 1D arrays of simple data types.
+ array.
+
+ :param qt.QObject parent: Parent object
+ :param numpy.ndarray data: A numpy array or a h5py dataset
+ """
+ def __init__(self, parent=None, data=None):
+ qt.QAbstractTableModel.__init__(self, parent)
+
+ self.__data = None
+ self.__is_array = False
+ self.__fields = None
+ self.__formatter = None
+ self.__editFormatter = None
+ self.setFormatter(TextFormatter(self))
+
+ # set _data
+ self.setArrayData(data)
+
+ # Methods to be implemented to subclass QAbstractTableModel
+ def rowCount(self, parent_idx=None):
+ """Returns number of rows to be displayed in table"""
+ if self.__data is None:
+ return 0
+ elif not self.__is_array:
+ return 1
+ else:
+ return len(self.__data)
+
+ def columnCount(self, parent_idx=None):
+ """Returns number of columns to be displayed in table"""
+ if self.__fields is None:
+ return 1
+ else:
+ return len(self.__fields)
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method to access data values
+ in the format ready to be displayed"""
+ if not index.isValid():
+ return None
+
+ if self.__data is None:
+ return None
+
+ if self.__is_array:
+ if index.row() >= len(self.__data):
+ return None
+ data = self.__data[index.row()]
+ else:
+ if index.row() > 0:
+ return None
+ data = self.__data
+
+ if self.__fields is not None:
+ if index.column() >= len(self.__fields):
+ return None
+ key = self.__fields[index.column()][1]
+ data = data[key[0]]
+ if len(key) > 1:
+ data = data[key[1]]
+
+ if role == qt.Qt.DisplayRole:
+ return self.__formatter.toString(data)
+ elif role == qt.Qt.EditRole:
+ return self.__editFormatter.toString(data)
+ return None
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ """Returns the 0-based row or column index, for display in the
+ horizontal and vertical headers"""
+ if section == -1:
+ # PyQt4 send -1 when there is columns but no rows
+ return None
+
+ if role == qt.Qt.DisplayRole:
+ if orientation == qt.Qt.Vertical:
+ if not self.__is_array:
+ return "Scalar"
+ else:
+ return str(section)
+ if orientation == qt.Qt.Horizontal:
+ if self.__fields is None:
+ if section == 0:
+ return "Data"
+ else:
+ return None
+ else:
+ if section < len(self.__fields):
+ return self.__fields[section][0]
+ else:
+ return None
+ return None
+
+ def flags(self, index):
+ """QAbstractTableModel method to inform the view whether data
+ is editable or not.
+ """
+ return qt.QAbstractTableModel.flags(self, index)
+
+ def setArrayData(self, data):
+ """Set the data array and the viewing perspective.
+
+ You can set ``copy=False`` if you need more performances, when dealing
+ with a large numpy array. In this case, a simple reference to the data
+ is used to access the data, rather than a copy of the array.
+
+ .. warning::
+
+ Any change to the data model will affect your original data
+ array, when using a reference rather than a copy..
+
+ :param data: 1D numpy array, or any object that can be
+ converted to a numpy array using ``numpy.array(data)`` (e.g.
+ a nested sequence).
+ """
+ if qt.qVersion() > "4.6":
+ self.beginResetModel()
+
+ self.__data = data
+ if isinstance(data, numpy.ndarray):
+ self.__is_array = True
+ elif silx.io.is_dataset(data) and data.shape != tuple():
+ self.__is_array = True
+ else:
+ self.__is_array = False
+
+
+ self.__fields = []
+ if data is not None:
+ if data.dtype.fields is not None:
+ for name, (dtype, _index) in data.dtype.fields.items():
+ if dtype.shape != tuple():
+ keys = itertools.product(*[range(x) for x in dtype.shape])
+ for key in keys:
+ label = "%s%s" % (name, list(key))
+ array_key = (name, key)
+ self.__fields.append((label, array_key))
+ else:
+ self.__fields.append((name, (name,)))
+ else:
+ self.__fields = None
+
+ if qt.qVersion() > "4.6":
+ self.endResetModel()
+ else:
+ self.reset()
+
+ def arrayData(self):
+ """Returns the internal data.
+
+ :rtype: numpy.ndarray of h5py.Dataset
+ """
+ return self.__data
+
+ def setFormatter(self, formatter):
+ """Set the formatter object to be used to display data from the model
+
+ :param TextFormatter formatter: Formatter to use
+ """
+ if formatter is self.__formatter:
+ return
+
+ if qt.qVersion() > "4.6":
+ self.beginResetModel()
+
+ if self.__formatter is not None:
+ self.__formatter.formatChanged.disconnect(self.__formatChanged)
+
+ self.__formatter = formatter
+ self.__editFormatter = TextFormatter(formatter)
+ self.__editFormatter.setUseQuoteForText(False)
+
+ if self.__formatter is not None:
+ self.__formatter.formatChanged.connect(self.__formatChanged)
+
+ if qt.qVersion() > "4.6":
+ self.endResetModel()
+ else:
+ self.reset()
+
+ def getFormatter(self):
+ """Returns the text formatter used.
+
+ :rtype: TextFormatter
+ """
+ return self.__formatter
+
+ def __formatChanged(self):
+ """Called when the format changed.
+ """
+ self.__editFormatter = TextFormatter(self, self.getFormatter())
+ self.__editFormatter.setUseQuoteForText(False)
+ self.reset()
+
+
+class _ShowEditorProxyModel(qt.QIdentityProxyModel):
+ """
+ Allow to custom the flag edit of the model
+ """
+
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QObject arent: parent object
+ """
+ super(_ShowEditorProxyModel, self).__init__(parent)
+ self.__forceEditable = False
+
+ def flags(self, index):
+ flag = qt.QIdentityProxyModel.flags(self, index)
+ if self.__forceEditable:
+ flag = flag | qt.Qt.ItemIsEditable
+ return flag
+
+ def forceCellEditor(self, show):
+ """
+ Enable the editable flag to allow to display cell editor.
+ """
+ if self.__forceEditable == show:
+ return
+ self.beginResetModel()
+ self.__forceEditable = show
+ self.endResetModel()
+
+
+class RecordTableView(qt.QTableView):
+ """TableView using DatabaseTableModel as default model.
+ """
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QWidget parent: parent QWidget
+ """
+ qt.QTableView.__init__(self, parent)
+
+ model = _ShowEditorProxyModel(self)
+ model.setSourceModel(RecordTableModel())
+ self.setModel(model)
+ self.__multilineView = _MultiLineItem(self)
+ self.setEditTriggers(qt.QAbstractItemView.AllEditTriggers)
+ self._copyAction = CopySelectedCellsAction(self)
+ self.addAction(self._copyAction)
+
+ def copy(self):
+ self._copyAction.trigger()
+
+ def setArrayData(self, data):
+ self.model().sourceModel().setArrayData(data)
+ if data is not None:
+ if issubclass(data.dtype.type, (numpy.string_, numpy.unicode_)):
+ # TODO it would be nice to also fix fields
+ # but using it only for string array is already very useful
+ self.setItemDelegateForColumn(0, self.__multilineView)
+ self.model().forceCellEditor(True)
+ else:
+ self.setItemDelegateForColumn(0, None)
+ self.model().forceCellEditor(False)
diff --git a/silx/gui/data/TextFormatter.py b/silx/gui/data/TextFormatter.py
new file mode 100644
index 0000000..f074de5
--- /dev/null
+++ b/silx/gui/data/TextFormatter.py
@@ -0,0 +1,222 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a class sharred by widget from the
+data module to format data as text in the same way."""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "26/04/2017"
+
+import numpy
+import numbers
+import binascii
+from silx.third_party import six
+from silx.gui import qt
+
+
+class TextFormatter(qt.QObject):
+ """Formatter to convert data to string.
+
+ The method :meth:`toString` returns a formatted string from an input data
+ using parameters set to this object.
+
+ It support most python and numpy data, expecting dictionary. Unsupported
+ data are displayed using the string representation of the object (`str`).
+
+ It provides a set of parameters to custom the formatting of integer and
+ float values (:meth:`setIntegerFormat`, :meth:`setFloatFormat`).
+
+ It also allows to custom the use of quotes to display text data
+ (:meth:`setUseQuoteForText`), and custom unit used to display imaginary
+ numbers (:meth:`setImaginaryUnit`).
+
+ The object emit an event `formatChanged` every time a parametter is
+ changed.
+ """
+
+ formatChanged = qt.Signal()
+ """Emitted when properties of the formatter change."""
+
+ def __init__(self, parent=None, formatter=None):
+ """
+ Constructor
+
+ :param qt.QObject parent: Owner of the object
+ :param TextFormatter formatter: Instantiate this object from the
+ formatter
+ """
+ qt.QObject.__init__(self, parent)
+ if formatter is not None:
+ self.__integerFormat = formatter.integerFormat()
+ self.__floatFormat = formatter.floatFormat()
+ self.__useQuoteForText = formatter.useQuoteForText()
+ self.__imaginaryUnit = formatter.imaginaryUnit()
+ else:
+ self.__integerFormat = "%d"
+ self.__floatFormat = "%g"
+ self.__useQuoteForText = True
+ self.__imaginaryUnit = u"j"
+
+ def integerFormat(self):
+ """Returns the format string controlling how the integer data
+ are formated by this object.
+
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+
+ :rtype: str
+ """
+ return self.__integerFormat
+
+ def setIntegerFormat(self, value):
+ """Set format string controlling how the integer data are
+ formated by this object.
+
+ :param str value: Format string (e.g. "%d", "%i", "%08i").
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+ """
+ if self.__integerFormat == value:
+ return
+ self.__integerFormat = value
+ self.formatChanged.emit()
+
+ def floatFormat(self):
+ """Returns the format string controlling how the floating-point data
+ are formated by this object.
+
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+
+ :rtype: str
+ """
+ return self.__floatFormat
+
+ def setFloatFormat(self, value):
+ """Set format string controlling how the floating-point data are
+ formated by this object.
+
+ :param str value: Format string (e.g. "%.3f", "%d", "%-10.2f",
+ "%10.3e").
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+ """
+ if self.__floatFormat == value:
+ return
+ self.__floatFormat = value
+ self.formatChanged.emit()
+
+ def useQuoteForText(self):
+ """Returns true if the string data are formatted using double quotes.
+
+ Else, no quotes are used.
+ """
+ return self.__integerFormat
+
+ def setUseQuoteForText(self, useQuote):
+ """Set the use of quotes to delimit string data.
+
+ :param bool useQuote: True to use quotes.
+ """
+ if self.__useQuoteForText == useQuote:
+ return
+ self.__useQuoteForText = useQuote
+ self.formatChanged.emit()
+
+ def imaginaryUnit(self):
+ """Returns the unit display for imaginary numbers.
+
+ :rtype: str
+ """
+ return self.__imaginaryUnit
+
+ def setImaginaryUnit(self, imaginaryUnit):
+ """Set the unit display for imaginary numbers.
+
+ :param str imaginaryUnit: Unit displayed after imaginary numbers
+ """
+ if self.__imaginaryUnit == imaginaryUnit:
+ return
+ self.__imaginaryUnit = imaginaryUnit
+ self.formatChanged.emit()
+
+ def toString(self, data):
+ """Format a data into a string using formatter options
+
+ :param object data: Data to render
+ :rtype: str
+ """
+ if isinstance(data, tuple):
+ text = [self.toString(d) for d in data]
+ return "(" + " ".join(text) + ")"
+ elif isinstance(data, (list, numpy.ndarray)):
+ text = [self.toString(d) for d in data]
+ return "[" + " ".join(text) + "]"
+ elif isinstance(data, numpy.void):
+ dtype = data.dtype
+ if data.dtype.fields is not None:
+ text = [self.toString(data[f]) for f in dtype.fields]
+ return "(" + " ".join(text) + ")"
+ return "0x" + binascii.hexlify(data).decode("ascii")
+ elif isinstance(data, (numpy.string_, numpy.object_, bytes)):
+ # This have to be done before checking python string inheritance
+ try:
+ text = "%s" % data.decode("utf-8")
+ if self.__useQuoteForText:
+ text = "\"%s\"" % text.replace("\"", "\\\"")
+ return text
+ except UnicodeDecodeError:
+ pass
+ return "0x" + binascii.hexlify(data).decode("ascii")
+ elif isinstance(data, six.string_types):
+ text = "%s" % data
+ if self.__useQuoteForText:
+ text = "\"%s\"" % text.replace("\"", "\\\"")
+ return text
+ elif isinstance(data, (numpy.integer, numbers.Integral)):
+ return self.__integerFormat % data
+ elif isinstance(data, (numbers.Real, numpy.floating)):
+ # It have to be done before complex checking
+ return self.__floatFormat % data
+ elif isinstance(data, (numpy.complex_, numbers.Complex)):
+ text = ""
+ if data.real != 0:
+ text += self.__floatFormat % data.real
+ if data.real != 0 and data.imag != 0:
+ if data.imag < 0:
+ template = self.__floatFormat + " - " + self.__floatFormat + self.__imaginaryUnit
+ params = (data.real, -data.imag)
+ else:
+ template = self.__floatFormat + " + " + self.__floatFormat + self.__imaginaryUnit
+ params = (data.real, data.imag)
+ else:
+ if data.imag != 0:
+ template = self.__floatFormat + self.__imaginaryUnit
+ params = (data.imag)
+ else:
+ template = self.__floatFormat
+ params = (data.real)
+ return template % params
+ return str(data)
diff --git a/silx/gui/data/__init__.py b/silx/gui/data/__init__.py
new file mode 100644
index 0000000..560062d
--- /dev/null
+++ b/silx/gui/data/__init__.py
@@ -0,0 +1,35 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of Qt widgets for displaying data arrays using
+table views and plot widgets.
+
+.. note::
+
+ Widgets in this package may rely on additional dependencies that are
+ not mandatory for *silx*.
+ :class:`DataViewer.DataViewer` relies on :mod:`silx.gui.plot` which
+ depends on *matplotlib*. It also optionally depends on *PyOpenGL* for 3D
+ visualization.
+"""
diff --git a/silx/gui/data/setup.py b/silx/gui/data/setup.py
new file mode 100644
index 0000000..23ccbdd
--- /dev/null
+++ b/silx/gui/data/setup.py
@@ -0,0 +1,41 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('data', parent_package, top_path)
+ config.add_subpackage('test')
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+ setup(configuration=configuration)
diff --git a/silx/gui/data/test/__init__.py b/silx/gui/data/test/__init__.py
new file mode 100644
index 0000000..08c044b
--- /dev/null
+++ b/silx/gui/data/test/__init__.py
@@ -0,0 +1,45 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import unittest
+
+from . import test_arraywidget
+from . import test_numpyaxesselector
+from . import test_dataviewer
+from . import test_textformatter
+
+__authors__ = ["V. Valls", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/01/2017"
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTests(
+ [test_arraywidget.suite(),
+ test_numpyaxesselector.suite(),
+ test_dataviewer.suite(),
+ test_textformatter.suite(),
+ ])
+ return test_suite
diff --git a/silx/gui/data/test/test_arraywidget.py b/silx/gui/data/test/test_arraywidget.py
new file mode 100644
index 0000000..bbd7ee5
--- /dev/null
+++ b/silx/gui/data/test/test_arraywidget.py
@@ -0,0 +1,320 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+import os
+import tempfile
+import unittest
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.data import ArrayTableWidget
+from silx.gui.test.utils import TestCaseQt
+
+try:
+ import h5py
+except ImportError:
+ h5py = None
+
+
+class TestArrayWidget(TestCaseQt):
+ """Basic test for ArrayTableWidget with a numpy array"""
+ def setUp(self):
+ super(TestArrayWidget, self).setUp()
+ self.aw = ArrayTableWidget.ArrayTableWidget()
+
+ def tearDown(self):
+ del self.aw
+ super(TestArrayWidget, self).tearDown()
+
+ def testShow(self):
+ """test for errors"""
+ self.aw.show()
+ self.qWaitForWindowExposed(self.aw)
+
+ def testSetData0D(self):
+ a = 1
+ self.aw.setArrayData(a)
+ b = self.aw.getData(copy=True)
+
+ self.assertTrue(numpy.array_equal(a, b))
+
+ # scalar/0D data has no frame index
+ self.assertEqual(len(self.aw.model._index), 0)
+ # and no perspective
+ self.assertEqual(len(self.aw.model._perspective), 0)
+
+ def testSetData1D(self):
+ a = [1, 2]
+ self.aw.setArrayData(a)
+ b = self.aw.getData(copy=True)
+
+ self.assertTrue(numpy.array_equal(a, b))
+
+ # 1D data has no frame index
+ self.assertEqual(len(self.aw.model._index), 0)
+ # and no perspective
+ self.assertEqual(len(self.aw.model._perspective), 0)
+
+ def testSetData4D(self):
+ a = numpy.reshape(numpy.linspace(0.213, 1.234, 1250),
+ (5, 5, 5, 10))
+ self.aw.setArrayData(a)
+
+ # default perspective (0, 1)
+ self.assertEqual(list(self.aw.model._perspective),
+ [0, 1])
+ self.aw.setPerspective((1, 3))
+ self.assertEqual(list(self.aw.model._perspective),
+ [1, 3])
+
+ b = self.aw.getData(copy=True)
+ self.assertTrue(numpy.array_equal(a, b))
+
+ # 4D data has a 2-tuple as frame index
+ self.assertEqual(len(self.aw.model._index), 2)
+ # default index is (0, 0)
+ self.assertEqual(list(self.aw.model._index),
+ [0, 0])
+ self.aw.setFrameIndex((3, 1))
+
+ self.assertEqual(list(self.aw.model._index),
+ [3, 1])
+
+ def testColors(self):
+ a = numpy.arange(256, dtype=numpy.uint8)
+ self.aw.setArrayData(a)
+
+ bgcolor = numpy.empty(a.shape + (3,), dtype=numpy.uint8)
+ # Black & white palette
+ bgcolor[..., 0] = a
+ bgcolor[..., 1] = a
+ bgcolor[..., 2] = a
+
+ fgcolor = numpy.bitwise_xor(bgcolor, 255)
+
+ self.aw.setArrayColors(bgcolor, fgcolor)
+
+ # test colors are as expected in model
+ for i in range(256):
+ # all RGB channels for BG equal to data value
+ self.assertEqual(
+ self.aw.model.data(self.aw.model.index(0, i),
+ role=qt.Qt.BackgroundRole),
+ qt.QColor(i, i, i),
+ "Unexpected background color"
+ )
+
+ # all RGB channels for FG equal to XOR(data value, 255)
+ self.assertEqual(
+ self.aw.model.data(self.aw.model.index(0, i),
+ role=qt.Qt.ForegroundRole),
+ qt.QColor(i ^ 255, i ^ 255, i ^ 255),
+ "Unexpected text color"
+ )
+
+ # test colors are reset to None when a new data array is loaded
+ # with different shape
+ self.aw.setArrayData(numpy.arange(300))
+
+ for i in range(300):
+ # all RGB channels for BG equal to data value
+ self.assertIsNone(
+ self.aw.model.data(self.aw.model.index(0, i),
+ role=qt.Qt.BackgroundRole))
+
+ def testDefaultFlagNotEditable(self):
+ """editable should be False by default, in setArrayData"""
+ self.aw.setArrayData([[0]])
+ idx = self.aw.model.createIndex(0, 0)
+ # model is editable
+ self.assertFalse(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+
+ def testFlagEditable(self):
+ self.aw.setArrayData([[0]], editable=True)
+ idx = self.aw.model.createIndex(0, 0)
+ # model is editable
+ self.assertTrue(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+
+ def testFlagNotEditable(self):
+ self.aw.setArrayData([[0]], editable=False)
+ idx = self.aw.model.createIndex(0, 0)
+ # model is editable
+ self.assertFalse(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+
+ def testReferenceReturned(self):
+ """when setting the data with copy=False and
+ retrieving it with getData(copy=False), we should recover
+ the same original object.
+ """
+ # n-D (n >=2)
+ a0 = numpy.reshape(numpy.linspace(0.213, 1.234, 1000),
+ (10, 10, 10))
+ self.aw.setArrayData(a0, copy=False)
+ a1 = self.aw.getData(copy=False)
+
+ self.assertIs(a0, a1)
+
+ # 1D
+ b0 = numpy.linspace(0.213, 1.234, 1000)
+ self.aw.setArrayData(b0, copy=False)
+ b1 = self.aw.getData(copy=False)
+ self.assertIs(b0, b1)
+
+
+@unittest.skipIf(h5py is None, "Could not import h5py")
+class TestH5pyArrayWidget(TestCaseQt):
+ """Basic test for ArrayTableWidget with a dataset.
+
+ Test flags, for dataset open in read-only or read-write modes"""
+ def setUp(self):
+ super(TestH5pyArrayWidget, self).setUp()
+ self.aw = ArrayTableWidget.ArrayTableWidget()
+ self.data = numpy.reshape(numpy.linspace(0.213, 1.234, 1000),
+ (10, 10, 10))
+ # create an h5py file with a dataset
+ self.tempdir = tempfile.mkdtemp()
+ self.h5_fname = os.path.join(self.tempdir, "array.h5")
+ h5f = h5py.File(self.h5_fname)
+ h5f["my_array"] = self.data
+ h5f["my_scalar"] = 3.14
+ h5f["my_1D_array"] = numpy.array(numpy.arange(1000))
+ h5f.close()
+
+ def tearDown(self):
+ del self.aw
+ os.unlink(self.h5_fname)
+ os.rmdir(self.tempdir)
+ super(TestH5pyArrayWidget, self).tearDown()
+
+ def testShow(self):
+ self.aw.show()
+ self.qWaitForWindowExposed(self.aw)
+
+ def testReadOnly(self):
+ """Open H5 dataset in read-only mode, ensure the model is not editable."""
+ h5f = h5py.File(self.h5_fname, "r")
+ a = h5f["my_array"]
+ # ArrayTableModel relies on following condition
+ self.assertTrue(a.file.mode == "r")
+
+ self.aw.setArrayData(a, copy=False, editable=True)
+
+ self.assertIsInstance(a, h5py.Dataset) # simple sanity check
+ # internal representation must be a reference to original data (copy=False)
+ self.assertIsInstance(self.aw.model._array, h5py.Dataset)
+ self.assertTrue(self.aw.model._array.file.mode == "r")
+
+ b = self.aw.getData()
+ self.assertTrue(numpy.array_equal(self.data, b))
+
+ # model must have detected read-only dataset and disabled editing
+ self.assertFalse(self.aw.model._editable)
+ idx = self.aw.model.createIndex(0, 0)
+ self.assertFalse(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+
+ # force editing read-only datasets raises IOError
+ self.assertRaises(IOError, self.aw.model.setData,
+ idx, 123.4, role=qt.Qt.EditRole)
+ h5f.close()
+
+ def testReadWrite(self):
+ h5f = h5py.File(self.h5_fname, "r+")
+ a = h5f["my_array"]
+ self.assertTrue(a.file.mode == "r+")
+
+ self.aw.setArrayData(a, copy=False, editable=True)
+ b = self.aw.getData(copy=False)
+ self.assertTrue(numpy.array_equal(self.data, b))
+
+ idx = self.aw.model.createIndex(0, 0)
+ # model is editable
+ self.assertTrue(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+ h5f.close()
+
+ def testSetData0D(self):
+ h5f = h5py.File(self.h5_fname, "r+")
+ a = h5f["my_scalar"]
+ self.aw.setArrayData(a)
+ b = self.aw.getData(copy=True)
+
+ self.assertTrue(numpy.array_equal(a, b))
+
+ h5f.close()
+
+ def testSetData1D(self):
+ h5f = h5py.File(self.h5_fname, "r+")
+ a = h5f["my_1D_array"]
+ self.aw.setArrayData(a)
+ b = self.aw.getData(copy=True)
+
+ self.assertTrue(numpy.array_equal(a, b))
+
+ h5f.close()
+
+ def testReferenceReturned(self):
+ """when setting the data with copy=False and
+ retrieving it with getData(copy=False), we should recover
+ the same original object.
+
+ This only works for array with at least 2D. For 1D and 0D
+ arrays, a view is created at some point, which in the case
+ of an hdf5 dataset creates a copy."""
+ h5f = h5py.File(self.h5_fname, "r+")
+
+ # n-D
+ a0 = h5f["my_array"]
+ self.aw.setArrayData(a0, copy=False)
+ a1 = self.aw.getData(copy=False)
+ self.assertIs(a0, a1)
+
+ # 1D
+ b0 = h5f["my_1D_array"]
+ self.aw.setArrayData(b0, copy=False)
+ b1 = self.aw.getData(copy=False)
+ self.assertIs(b0, b1)
+
+ h5f.close()
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestArrayWidget))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestH5pyArrayWidget))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/data/test/test_dataviewer.py b/silx/gui/data/test/test_dataviewer.py
new file mode 100644
index 0000000..5a0de0b
--- /dev/null
+++ b/silx/gui/data/test/test_dataviewer.py
@@ -0,0 +1,281 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "10/04/2017"
+
+import os
+import tempfile
+import unittest
+from contextlib import contextmanager
+
+import numpy
+from ..DataViewer import DataViewer
+from ..DataViews import DataView
+from .. import DataViews
+
+from silx.gui import qt
+
+from silx.gui.data.DataViewerFrame import DataViewerFrame
+from silx.gui.test.utils import SignalListener
+from silx.gui.test.utils import TestCaseQt
+
+from silx.gui.hdf5.test import _mock
+
+try:
+ import h5py
+except ImportError:
+ h5py = None
+
+
+class _DataViewMock(DataView):
+ """Dummy view to display nothing"""
+
+ def __init__(self, parent):
+ DataView.__init__(self, parent)
+
+ def axesNames(self, data, info):
+ return []
+
+ def createWidget(self, parent):
+ return qt.QLabel(parent)
+
+ def getDataPriority(self, data, info):
+ return 0
+
+
+class AbstractDataViewerTests(TestCaseQt):
+
+ def create_widget(self):
+ raise NotImplementedError()
+
+ @contextmanager
+ def h5_temporary_file(self):
+ # create tmp file
+ fd, tmp_name = tempfile.mkstemp(suffix=".h5")
+ os.close(fd)
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ # create h5 data
+ h5file = h5py.File(tmp_name, "w")
+ h5file["data"] = data
+ yield h5file
+ # clean up
+ h5file.close()
+ os.unlink(tmp_name)
+
+ def test_text_data(self):
+ data_list = ["aaa", int, 8, self]
+ widget = self.create_widget()
+ for data in data_list:
+ widget.setData(data)
+ self.assertEqual(DataViewer.RAW_MODE, widget.displayMode())
+
+ def test_plot_1d_data(self):
+ data = numpy.arange(3 ** 1)
+ data.shape = [3] * 1
+ widget = self.create_widget()
+ widget.setData(data)
+ availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
+ self.assertEqual(DataViewer.RAW_MODE, widget.displayMode())
+ self.assertIn(DataViewer.PLOT1D_MODE, availableModes)
+
+ def test_plot_2d_data(self):
+ data = numpy.arange(3 ** 2)
+ data.shape = [3] * 2
+ widget = self.create_widget()
+ widget.setData(data)
+ availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
+ self.assertEqual(DataViewer.RAW_MODE, widget.displayMode())
+ self.assertIn(DataViewer.PLOT2D_MODE, availableModes)
+
+ def test_plot_3d_data(self):
+ data = numpy.arange(3 ** 3)
+ data.shape = [3] * 3
+ widget = self.create_widget()
+ widget.setData(data)
+ availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
+ try:
+ import silx.gui.plot3d # noqa
+ self.assertIn(DataViewer.PLOT3D_MODE, availableModes)
+ except ImportError:
+ self.assertIn(DataViewer.STACK_MODE, availableModes)
+ self.assertEqual(DataViewer.RAW_MODE, widget.displayMode())
+
+ def test_array_1d_data(self):
+ data = numpy.array(["aaa"] * (3 ** 1))
+ data.shape = [3] * 1
+ widget = self.create_widget()
+ widget.setData(data)
+ self.assertEqual(DataViewer.RAW_MODE, widget.displayedView().modeId())
+
+ def test_array_2d_data(self):
+ data = numpy.array(["aaa"] * (3 ** 2))
+ data.shape = [3] * 2
+ widget = self.create_widget()
+ widget.setData(data)
+ self.assertEqual(DataViewer.RAW_MODE, widget.displayedView().modeId())
+
+ def test_array_4d_data(self):
+ data = numpy.array(["aaa"] * (3 ** 4))
+ data.shape = [3] * 4
+ widget = self.create_widget()
+ widget.setData(data)
+ self.assertEqual(DataViewer.RAW_MODE, widget.displayedView().modeId())
+
+ def test_record_4d_data(self):
+ data = numpy.zeros(3 ** 4, dtype='3int8, float32, (2,3)float64')
+ data.shape = [3] * 4
+ widget = self.create_widget()
+ widget.setData(data)
+ self.assertEqual(DataViewer.RAW_MODE, widget.displayedView().modeId())
+
+ def test_3d_h5_dataset(self):
+ if h5py is None:
+ self.skipTest("h5py library is not available")
+ with self.h5_temporary_file() as h5file:
+ dataset = h5file["data"]
+ widget = self.create_widget()
+ widget.setData(dataset)
+
+ def test_data_event(self):
+ listener = SignalListener()
+ widget = self.create_widget()
+ widget.dataChanged.connect(listener)
+ widget.setData(10)
+ widget.setData(None)
+ self.assertEquals(listener.callCount(), 2)
+
+ def test_display_mode_event(self):
+ listener = SignalListener()
+ widget = self.create_widget()
+ widget.displayedViewChanged.connect(listener)
+ widget.setData(10)
+ widget.setData(None)
+ modes = [v.modeId() for v in listener.arguments(argumentIndex=0)]
+ self.assertEquals(modes, [DataViewer.RAW_MODE, DataViewer.EMPTY_MODE])
+ listener.clear()
+
+ def test_change_display_mode(self):
+ data = numpy.arange(10 ** 4)
+ data.shape = [10] * 4
+ widget = self.create_widget()
+ widget.setData(data)
+ widget.setDisplayMode(DataViewer.PLOT1D_MODE)
+ self.assertEquals(widget.displayedView().modeId(), DataViewer.PLOT1D_MODE)
+ widget.setDisplayMode(DataViewer.PLOT2D_MODE)
+ self.assertEquals(widget.displayedView().modeId(), DataViewer.PLOT2D_MODE)
+ widget.setDisplayMode(DataViewer.RAW_MODE)
+ self.assertEquals(widget.displayedView().modeId(), DataViewer.RAW_MODE)
+ widget.setDisplayMode(DataViewer.EMPTY_MODE)
+ self.assertEquals(widget.displayedView().modeId(), DataViewer.EMPTY_MODE)
+
+ def test_create_default_views(self):
+ widget = self.create_widget()
+ views = widget.createDefaultViews()
+ self.assertTrue(len(views) > 0)
+
+ def test_add_view(self):
+ widget = self.create_widget()
+ view = _DataViewMock(widget)
+ widget.addView(view)
+ self.assertTrue(view in widget.availableViews())
+ self.assertTrue(view in widget.currentAvailableViews())
+
+ def test_remove_view(self):
+ widget = self.create_widget()
+ widget.setData("foobar")
+ view = widget.currentAvailableViews()[0]
+ widget.removeView(view)
+ self.assertTrue(view not in widget.availableViews())
+ self.assertTrue(view not in widget.currentAvailableViews())
+
+class TestDataViewer(AbstractDataViewerTests):
+ def create_widget(self):
+ return DataViewer()
+
+
+class TestDataViewerFrame(AbstractDataViewerTests):
+ def create_widget(self):
+ return DataViewerFrame()
+
+
+class TestDataView(TestCaseQt):
+
+ def createComplexData(self):
+ line = [1, 2j, 3+3j, 4]
+ image = [line, line, line, line]
+ cube = [image, image, image, image]
+ data = numpy.array(cube,
+ dtype=numpy.complex)
+ return data
+
+ def createDataViewWithData(self, dataViewClass, data):
+ viewer = dataViewClass(None)
+ widget = viewer.getWidget()
+ viewer.setData(data)
+ return widget
+
+ def testCurveWithComplex(self):
+ data = self.createComplexData()
+ dataViewClass = DataViews._Plot1dView
+ widget = self.createDataViewWithData(dataViewClass, data[0, 0])
+ self.qWaitForWindowExposed(widget)
+
+ def testImageWithComplex(self):
+ data = self.createComplexData()
+ dataViewClass = DataViews._Plot2dView
+ widget = self.createDataViewWithData(dataViewClass, data[0])
+ self.qWaitForWindowExposed(widget)
+
+ def testCubeWithComplex(self):
+ self.skipTest("OpenGL widget not yet tested")
+ try:
+ import silx.gui.plot3d # noqa
+ except ImportError:
+ self.skipTest("OpenGL not available")
+ data = self.createComplexData()
+ dataViewClass = DataViews._Plot3dView
+ widget = self.createDataViewWithData(dataViewClass, data)
+ self.qWaitForWindowExposed(widget)
+
+ def testImageStackWithComplex(self):
+ data = self.createComplexData()
+ dataViewClass = DataViews._StackView
+ widget = self.createDataViewWithData(dataViewClass, data)
+ self.qWaitForWindowExposed(widget)
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ loadTestsFromTestCase = unittest.defaultTestLoader.loadTestsFromTestCase
+ test_suite.addTest(loadTestsFromTestCase(TestDataViewer))
+ test_suite.addTest(loadTestsFromTestCase(TestDataViewerFrame))
+ test_suite.addTest(loadTestsFromTestCase(TestDataView))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/data/test/test_numpyaxesselector.py b/silx/gui/data/test/test_numpyaxesselector.py
new file mode 100644
index 0000000..cc15f83
--- /dev/null
+++ b/silx/gui/data/test/test_numpyaxesselector.py
@@ -0,0 +1,152 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "15/12/2016"
+
+import os
+import tempfile
+import unittest
+from contextlib import contextmanager
+
+import numpy
+
+from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
+from silx.gui.test.utils import SignalListener
+from silx.gui.test.utils import TestCaseQt
+
+try:
+ import h5py
+except ImportError:
+ h5py = None
+
+
+class TestNumpyAxesSelector(TestCaseQt):
+
+ def test_creation(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ widget = NumpyAxesSelector()
+ widget.setVisible(True)
+
+ def test_none(self):
+ data = numpy.arange(3 * 3 * 3)
+ widget = NumpyAxesSelector()
+ widget.setData(data)
+ widget.setData(None)
+ result = widget.selectedData()
+ self.assertIsNone(result)
+
+ def test_output_samedim(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ expectedResult = data
+
+ widget = NumpyAxesSelector()
+ widget.setAxisNames(["x", "y", "z"])
+ widget.setData(data)
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ def test_output_lessdim(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ expectedResult = data[0]
+
+ widget = NumpyAxesSelector()
+ widget.setAxisNames(["y", "x"])
+ widget.setData(data)
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ def test_output_1dim(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ expectedResult = data[0, 0, 0]
+
+ widget = NumpyAxesSelector()
+ widget.setData(data)
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ @contextmanager
+ def h5_temporary_file(self):
+ # create tmp file
+ fd, tmp_name = tempfile.mkstemp(suffix=".h5")
+ os.close(fd)
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ # create h5 data
+ h5file = h5py.File(tmp_name, "w")
+ h5file["data"] = data
+ yield h5file
+ # clean up
+ h5file.close()
+ os.unlink(tmp_name)
+
+ def test_h5py_dataset(self):
+ if h5py is None:
+ self.skipTest("h5py library is not available")
+ with self.h5_temporary_file() as h5file:
+ dataset = h5file["data"]
+ expectedResult = dataset[0]
+
+ widget = NumpyAxesSelector()
+ widget.setData(dataset)
+ widget.setAxisNames(["y", "x"])
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ def test_data_event(self):
+ data = numpy.arange(3 * 3 * 3)
+ widget = NumpyAxesSelector()
+ listener = SignalListener()
+ widget.dataChanged.connect(listener)
+ widget.setData(data)
+ widget.setData(None)
+ self.assertEqual(listener.callCount(), 2)
+
+ def test_selected_data_event(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ widget = NumpyAxesSelector()
+ listener = SignalListener()
+ widget.selectionChanged.connect(listener)
+ widget.setData(data)
+ widget.setAxisNames(["x"])
+ widget.setData(None)
+ self.assertEqual(listener.callCount(), 3)
+ listener.clear()
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestNumpyAxesSelector))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/data/test/test_textformatter.py b/silx/gui/data/test/test_textformatter.py
new file mode 100644
index 0000000..f21e033
--- /dev/null
+++ b/silx/gui/data/test/test_textformatter.py
@@ -0,0 +1,94 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "24/01/2017"
+
+import unittest
+
+from silx.gui.test.utils import TestCaseQt
+from silx.gui.test.utils import SignalListener
+from ..TextFormatter import TextFormatter
+
+
+class TestTextFormatter(TestCaseQt):
+
+ def test_copy(self):
+ formatter = TextFormatter()
+ copy = TextFormatter(formatter=formatter)
+ self.assertIsNot(formatter, copy)
+ copy.setFloatFormat("%.3f")
+ self.assertEquals(formatter.integerFormat(), copy.integerFormat())
+ self.assertNotEquals(formatter.floatFormat(), copy.floatFormat())
+ self.assertEquals(formatter.useQuoteForText(), copy.useQuoteForText())
+ self.assertEquals(formatter.imaginaryUnit(), copy.imaginaryUnit())
+
+ def test_event(self):
+ listener = SignalListener()
+ formatter = TextFormatter()
+ formatter.formatChanged.connect(listener)
+ formatter.setFloatFormat("%.3f")
+ formatter.setIntegerFormat("%03i")
+ formatter.setUseQuoteForText(False)
+ formatter.setImaginaryUnit("z")
+ self.assertEquals(listener.callCount(), 4)
+
+ def test_int(self):
+ formatter = TextFormatter()
+ formatter.setIntegerFormat("%05i")
+ result = formatter.toString(512)
+ self.assertEquals(result, "00512")
+
+ def test_float(self):
+ formatter = TextFormatter()
+ formatter.setFloatFormat("%.3f")
+ result = formatter.toString(1.3)
+ self.assertEquals(result, "1.300")
+
+ def test_complex(self):
+ formatter = TextFormatter()
+ formatter.setFloatFormat("%.1f")
+ formatter.setImaginaryUnit("i")
+ result = formatter.toString(1.0 + 5j)
+ result = result.replace(" ", "")
+ self.assertEquals(result, "1.0+5.0i")
+
+ def test_string(self):
+ formatter = TextFormatter()
+ formatter.setIntegerFormat("%.1f")
+ formatter.setImaginaryUnit("z")
+ result = formatter.toString("toto")
+ self.assertEquals(result, '"toto"')
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestTextFormatter))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/fit/BackgroundWidget.py b/silx/gui/fit/BackgroundWidget.py
new file mode 100644
index 0000000..577a8c7
--- /dev/null
+++ b/silx/gui/fit/BackgroundWidget.py
@@ -0,0 +1,530 @@
+# coding: utf-8
+#/*##########################################################################
+# Copyright (C) 2004-2017 V.A. Sole, European Synchrotron Radiation Facility
+#
+# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
+# the ESRF by the Software group.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# #########################################################################*/
+"""This module provides a background configuration widget
+:class:`BackgroundWidget` and a corresponding dialog window
+:class:`BackgroundDialog`."""
+import sys
+import numpy
+from silx.gui import qt
+from silx.gui.plot import PlotWidget
+from silx.math.fit import filters
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/01/2017"
+
+
+class HorizontalSpacer(qt.QWidget):
+ def __init__(self, *args):
+ qt.QWidget.__init__(self, *args)
+ self.setSizePolicy(qt.QSizePolicy(qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Fixed))
+
+
+class BackgroundParamWidget(qt.QWidget):
+ """Background configuration composite widget.
+
+ Strip and snip filters parameters can be adjusted using input widgets.
+
+ Updating the widgets causes :attr:`sigBackgroundParamWidgetSignal` to
+ be emitted.
+ """
+ sigBackgroundParamWidgetSignal = qt.pyqtSignal(object)
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+
+ self.mainLayout = qt.QGridLayout(self)
+ self.mainLayout.setColumnStretch(1, 1)
+
+ # Algorithm choice ---------------------------------------------------
+ self.algorithmComboLabel = qt.QLabel(self)
+ self.algorithmComboLabel.setText("Background algorithm")
+ self.algorithmCombo = qt.QComboBox(self)
+ self.algorithmCombo.addItem("Strip")
+ self.algorithmCombo.addItem("Snip")
+ self.algorithmCombo.activated[int].connect(
+ self._algorithmComboActivated)
+
+ # Strip parameters ---------------------------------------------------
+ self.stripWidthLabel = qt.QLabel(self)
+ self.stripWidthLabel.setText("Strip Width")
+
+ self.stripWidthSpin = qt.QSpinBox(self)
+ self.stripWidthSpin.setMaximum(100)
+ self.stripWidthSpin.setMinimum(1)
+ self.stripWidthSpin.valueChanged[int].connect(self._emitSignal)
+
+ self.stripIterLabel = qt.QLabel(self)
+ self.stripIterLabel.setText("Strip Iterations")
+ self.stripIterValue = qt.QLineEdit(self)
+ validator = qt.QIntValidator(self.stripIterValue)
+ self.stripIterValue._v = validator
+ self.stripIterValue.setText("0")
+ self.stripIterValue.editingFinished[()].connect(self._emitSignal)
+ self.stripIterValue.setToolTip(
+ "Number of iterations for strip algorithm.\n" +
+ "If greater than 999, an 2nd pass of strip filter is " +
+ "applied to remove artifacts created by first pass.")
+
+ # Snip parameters ----------------------------------------------------
+ self.snipWidthLabel = qt.QLabel(self)
+ self.snipWidthLabel.setText("Snip Width")
+
+ self.snipWidthSpin = qt.QSpinBox(self)
+ self.snipWidthSpin.setMaximum(300)
+ self.snipWidthSpin.setMinimum(0)
+ self.snipWidthSpin.valueChanged[int].connect(self._emitSignal)
+
+
+ # Smoothing parameters -----------------------------------------------
+ self.smoothingFlagCheck = qt.QCheckBox(self)
+ self.smoothingFlagCheck.setText("Smoothing Width (Savitsky-Golay)")
+ self.smoothingFlagCheck.toggled.connect(self._smoothingToggled)
+
+ self.smoothingSpin = qt.QSpinBox(self)
+ self.smoothingSpin.setMinimum(3)
+ #self.smoothingSpin.setMaximum(40)
+ self.smoothingSpin.setSingleStep(2)
+ self.smoothingSpin.valueChanged[int].connect(self._emitSignal)
+
+ # Anchors ------------------------------------------------------------
+
+ self.anchorsGroup = qt.QWidget(self)
+ anchorsLayout = qt.QHBoxLayout(self.anchorsGroup)
+ anchorsLayout.setSpacing(2)
+ anchorsLayout.setContentsMargins(0, 0, 0, 0)
+
+ self.anchorsFlagCheck = qt.QCheckBox(self.anchorsGroup)
+ self.anchorsFlagCheck.setText("Use anchors")
+ self.anchorsFlagCheck.setToolTip(
+ "Define X coordinates of points that must remain fixed")
+ self.anchorsFlagCheck.stateChanged[int].connect(
+ self._anchorsToggled)
+ anchorsLayout.addWidget(self.anchorsFlagCheck)
+
+ maxnchannel = 16384 * 4 # Fixme ?
+ self.anchorsList = []
+ num_anchors = 4
+ for i in range(num_anchors):
+ anchorSpin = qt.QSpinBox(self.anchorsGroup)
+ anchorSpin.setMinimum(0)
+ anchorSpin.setMaximum(maxnchannel)
+ anchorSpin.valueChanged[int].connect(self._emitSignal)
+ anchorsLayout.addWidget(anchorSpin)
+ self.anchorsList.append(anchorSpin)
+
+ # Layout ------------------------------------------------------------
+ self.mainLayout.addWidget(self.algorithmComboLabel, 0, 0)
+ self.mainLayout.addWidget(self.algorithmCombo, 0, 2)
+ self.mainLayout.addWidget(self.stripWidthLabel, 1, 0)
+ self.mainLayout.addWidget(self.stripWidthSpin, 1, 2)
+ self.mainLayout.addWidget(self.stripIterLabel, 2, 0)
+ self.mainLayout.addWidget(self.stripIterValue, 2, 2)
+ self.mainLayout.addWidget(self.snipWidthLabel, 3, 0)
+ self.mainLayout.addWidget(self.snipWidthSpin, 3, 2)
+ self.mainLayout.addWidget(self.smoothingFlagCheck, 4, 0)
+ self.mainLayout.addWidget(self.smoothingSpin, 4, 2)
+ self.mainLayout.addWidget(self.anchorsGroup, 5, 0, 1, 4)
+
+ # Initialize interface -----------------------------------------------
+ self._setAlgorithm("strip")
+ self.smoothingFlagCheck.setChecked(False)
+ self._smoothingToggled(is_checked=False)
+ self.anchorsFlagCheck.setChecked(False)
+ self._anchorsToggled(is_checked=False)
+
+ def _algorithmComboActivated(self, algorithm_index):
+ self._setAlgorithm("strip" if algorithm_index == 0 else "snip")
+
+ def _setAlgorithm(self, algorithm):
+ """Enable/disable snip and snip input widgets, depending on the
+ chosen algorithm.
+ :param algorithm: "snip" or "strip"
+ """
+ if algorithm not in ["strip", "snip"]:
+ raise ValueError(
+ "Unknown background filter algorithm %s" % algorithm)
+
+ self.algorithm = algorithm
+ self.stripWidthSpin.setEnabled(algorithm == "strip")
+ self.stripIterValue.setEnabled(algorithm == "strip")
+ self.snipWidthSpin.setEnabled(algorithm == "snip")
+
+ def _smoothingToggled(self, is_checked):
+ """Enable/disable smoothing input widgets, emit dictionary"""
+ self.smoothingSpin.setEnabled(is_checked)
+ self._emitSignal()
+
+ def _anchorsToggled(self, is_checked):
+ """Enable/disable all spin widgets defining anchor X coordinates,
+ emit signal.
+ """
+ for anchor_spin in self.anchorsList:
+ anchor_spin.setEnabled(is_checked)
+ self._emitSignal()
+
+ def setParameters(self, ddict):
+ """Set values for all input widgets.
+
+ :param dict ddict: Input dictionary, must have the same
+ keys as the dictionary output by :meth:`getParameters`
+ """
+ if "algorithm" in ddict:
+ self._setAlgorithm(ddict["algorithm"])
+
+ if "SnipWidth" in ddict:
+ self.snipWidthSpin.setValue(int(ddict["SnipWidth"]))
+
+ if "StripWidth" in ddict:
+ self.stripWidthSpin.setValue(int(ddict["StripWidth"]))
+
+ if "StripIterations" in ddict:
+ self.stripIterValue.setText("%d" % int(ddict["StripIterations"]))
+
+ if "SmoothingFlag" in ddict:
+ self.smoothingFlagCheck.setChecked(bool(ddict["SmoothingFlag"]))
+
+ if "SmoothingWidth" in ddict:
+ self.smoothingSpin.setValue(int(ddict["SmoothingWidth"]))
+
+ if "AnchorsFlag" in ddict:
+ self.anchorsFlagCheck.setChecked(bool(ddict["AnchorsFlag"]))
+
+ if "AnchorsList" in ddict:
+ anchorslist = ddict["AnchorsList"]
+ if anchorslist in [None, 'None']:
+ anchorslist = []
+ for spin in self.anchorsList:
+ spin.setValue(0)
+
+ i = 0
+ for value in anchorslist:
+ self.anchorsList[i].setValue(int(value))
+ i += 1
+
+ def getParameters(self):
+ """Return dictionary of parameters defined in the GUI
+
+ The returned dictionary contains following values:
+
+ - *algorithm*: *"strip"* or *"snip"*
+ - *StripWidth*: width of strip iterator
+ - *StripIterations*: number of iterations
+ - *StripThreshold*: curvature parameter (currently fixed to 1.0)
+ - *SnipWidth*: width of snip algorithm
+ - *SmoothingFlag*: flag to enable/disable smoothing
+ - *SmoothingWidth*: width of Savitsky-Golay smoothing filter
+ - *AnchorsFlag*: flag to enable/disable anchors
+ - *AnchorsList*: list of anchors (X coordinates of fixed values)
+ """
+ stripitertext = self.stripIterValue.text()
+ stripiter = int(stripitertext) if len(stripitertext) else 0
+
+ return {"algorithm": self.algorithm,
+ "StripThreshold": 1.0,
+ "SnipWidth": self.snipWidthSpin.value(),
+ "StripIterations": stripiter,
+ "StripWidth": self.stripWidthSpin.value(),
+ "SmoothingFlag": self.smoothingFlagCheck.isChecked(),
+ "SmoothingWidth": self.smoothingSpin.value(),
+ "AnchorsFlag": self.anchorsFlagCheck.isChecked(),
+ "AnchorsList": [spin.value() for spin in self.anchorsList]}
+
+ def _emitSignal(self, dummy=None):
+ self.sigBackgroundParamWidgetSignal.emit(
+ {'event': 'ParametersChanged',
+ 'parameters': self.getParameters()})
+
+
+class BackgroundWidget(qt.QWidget):
+ """Background configuration widget, with a :class:`PlotWindow`.
+
+ Strip and snip filters parameters can be adjusted using input widgets,
+ and the computed backgrounds are plotted next to the original data to
+ show the result."""
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+ self.setWindowTitle("Strip and SNIP Configuration Window")
+ self.mainLayout = qt.QVBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(2)
+ self.parametersWidget = BackgroundParamWidget(self)
+ self.graphWidget = PlotWidget(parent=self)
+ self.mainLayout.addWidget(self.parametersWidget)
+ self.mainLayout.addWidget(self.graphWidget)
+ self._x = None
+ self._y = None
+ self.parametersWidget.sigBackgroundParamWidgetSignal.connect(self._slot)
+
+ def getParameters(self):
+ """Return dictionary of parameters defined in the GUI
+
+ The returned dictionary contains following values:
+
+ - *algorithm*: *"strip"* or *"snip"*
+ - *StripWidth*: width of strip iterator
+ - *StripIterations*: number of iterations
+ - *StripThreshold*: strip curvature (currently fixed to 1.0)
+ - *SnipWidth*: width of snip algorithm
+ - *SmoothingFlag*: flag to enable/disable smoothing
+ - *SmoothingWidth*: width of Savitsky-Golay smoothing filter
+ - *AnchorsFlag*: flag to enable/disable anchors
+ - *AnchorsList*: list of anchors (X coordinates of fixed values)
+ """
+ return self.parametersWidget.getParameters()
+
+ def setParameters(self, ddict):
+ """Set values for all input widgets.
+
+ :param dict ddict: Input dictionary, must have the same
+ keys as the dictionary output by :meth:`getParameters`
+ """
+ return self.parametersWidget.setParameters(ddict)
+
+ def setData(self, x, y, xmin=None, xmax=None):
+ """Set data for the original curve, and _update strip and snip
+ curves accordingly.
+
+ :param x: Array or sequence of curve abscissa values
+ :param y: Array or sequence of curve ordinate values
+ :param xmin: Min value to be displayed on the X axis
+ :param xmax: Max value to be displayed on the X axis
+ """
+ self._x = x
+ self._y = y
+ self._xmin = xmin
+ self._xmax = xmax
+ self._update(resetzoom=True)
+
+ def _slot(self, ddict):
+ self._update()
+
+ def _update(self, resetzoom=False):
+ """Compute strip and snip backgrounds, update the curves
+ """
+ if self._y is None:
+ return
+
+ pars = self.getParameters()
+
+ # smoothed data
+ y = numpy.ravel(numpy.array(self._y)).astype(numpy.float)
+ if pars["SmoothingFlag"]:
+ ysmooth = filters.savitsky_golay(y, pars['SmoothingWidth'])
+ f = [0.25, 0.5, 0.25]
+ ysmooth[1:-1] = numpy.convolve(ysmooth, f, mode=0)
+ ysmooth[0] = 0.5 * (ysmooth[0] + ysmooth[1])
+ ysmooth[-1] = 0.5 * (ysmooth[-1] + ysmooth[-2])
+ else:
+ ysmooth = y
+
+
+ # loop for anchors
+ x = self._x
+ niter = pars['StripIterations']
+ anchors_indices = []
+ if pars['AnchorsFlag'] and pars['AnchorsList'] is not None:
+ ravelled = x
+ for channel in pars['AnchorsList']:
+ if channel <= ravelled[0]:
+ continue
+ index = numpy.nonzero(ravelled >= channel)[0]
+ if len(index):
+ index = min(index)
+ if index > 0:
+ anchors_indices.append(index)
+
+ stripBackground = filters.strip(ysmooth,
+ w=pars['StripWidth'],
+ niterations=niter,
+ factor=pars['StripThreshold'],
+ anchors=anchors_indices)
+
+ if niter >= 1000:
+ # final smoothing
+ stripBackground = filters.strip(stripBackground,
+ w=1,
+ niterations=50*pars['StripWidth'],
+ factor=pars['StripThreshold'],
+ anchors=anchors_indices)
+
+ if len(anchors_indices) == 0:
+ anchors_indices = [0, len(ysmooth)-1]
+ anchors_indices.sort()
+ snipBackground = 0.0 * ysmooth
+ lastAnchor = 0
+ for anchor in anchors_indices:
+ if (anchor > lastAnchor) and (anchor < len(ysmooth)):
+ snipBackground[lastAnchor:anchor] =\
+ filters.snip1d(ysmooth[lastAnchor:anchor],
+ pars['SnipWidth'])
+ lastAnchor = anchor
+ if lastAnchor < len(ysmooth):
+ snipBackground[lastAnchor:] =\
+ filters.snip1d(ysmooth[lastAnchor:],
+ pars['SnipWidth'])
+
+ self.graphWidget.addCurve(x, y,
+ legend='Input Data',
+ replace=True,
+ resetzoom=resetzoom)
+ self.graphWidget.addCurve(x, stripBackground,
+ legend='Strip Background',
+ resetzoom=False)
+ self.graphWidget.addCurve(x, snipBackground,
+ legend='SNIP Background',
+ resetzoom=False)
+ if self._xmin is not None and self._xmax is not None:
+ self.graphWidget.setGraphXLimits(xmin=self._xmin, xmax=self._xmax)
+
+
+class BackgroundDialog(qt.QDialog):
+ """QDialog window featuring a :class:`BackgroundWidget`"""
+ def __init__(self, parent=None):
+ qt.QDialog.__init__(self, parent)
+ self.setWindowTitle("Strip and Snip Configuration Window")
+ self.mainLayout = qt.QVBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(2)
+ self.parametersWidget = BackgroundWidget(self)
+ self.mainLayout.addWidget(self.parametersWidget)
+ hbox = qt.QWidget(self)
+ hboxLayout = qt.QHBoxLayout(hbox)
+ hboxLayout.setContentsMargins(0, 0, 0, 0)
+ hboxLayout.setSpacing(2)
+ self.okButton = qt.QPushButton(hbox)
+ self.okButton.setText("OK")
+ self.okButton.setAutoDefault(False)
+ self.dismissButton = qt.QPushButton(hbox)
+ self.dismissButton.setText("Cancel")
+ self.dismissButton.setAutoDefault(False)
+ hboxLayout.addWidget(HorizontalSpacer(hbox))
+ hboxLayout.addWidget(self.okButton)
+ hboxLayout.addWidget(self.dismissButton)
+ self.mainLayout.addWidget(hbox)
+ self.dismissButton.clicked.connect(self.reject)
+ self.okButton.clicked.connect(self.accept)
+
+ self.output = {}
+ """Configuration dictionary containing following fields:
+
+ - *SmoothingFlag*
+ - *SmoothingWidth*
+ - *StripWidth*
+ - *StripIterations*
+ - *StripThreshold*
+ - *SnipWidth*
+ - *AnchorsFlag*
+ - *AnchorsList*
+ """
+
+ # self.parametersWidget.parametersWidget.sigBackgroundParamWidgetSignal.connect(self.updateOutput)
+
+ # def updateOutput(self, ddict):
+ # self.output = ddict
+
+ def accept(self):
+ """Update :attr:`output`, then call :meth:`QDialog.accept`
+ """
+ self.output = self.getParameters()
+ super(BackgroundDialog, self).accept()
+
+ def sizeHint(self):
+ return qt.QSize(int(1.5*qt.QDialog.sizeHint(self).width()),
+ qt.QDialog.sizeHint(self).height())
+
+ def setData(self, x, y, xmin=None, xmax=None):
+ """See :meth:`BackgroundWidget.setData`"""
+ return self.parametersWidget.setData(x, y, xmin, xmax)
+
+ def getParameters(self):
+ """See :meth:`BackgroundWidget.getParameters`"""
+ return self.parametersWidget.getParameters()
+
+ def setParameters(self, ddict):
+ """See :meth:`BackgroundWidget.setParameters`"""
+ return self.parametersWidget.setParameters(ddict)
+
+ def setDefault(self, ddict):
+ """Alias for :meth:`setParameters`"""
+ return self.setParameters(ddict)
+
+
+def getBgDialog(parent=None, default=None, modal=True):
+ """Instantiate and return a bg configuration dialog, adapted
+ for configuring standard background theories from
+ :mod:`silx.math.fit.bgtheories`.
+
+ :return: Instance of :class:`BackgroundDialog`
+ """
+ bgd = BackgroundDialog(parent=parent)
+ # apply default to newly added pages
+ bgd.setParameters(default)
+
+ return bgd
+
+
+def main():
+ # synthetic data
+ from silx.math.fit.functions import sum_gauss
+
+ x = numpy.arange(5000)
+ # (height1, center1, fwhm1, ...) 5 peaks
+ params1 = (50, 500, 100,
+ 20, 2000, 200,
+ 50, 2250, 100,
+ 40, 3000, 75,
+ 23, 4000, 150)
+ y0 = sum_gauss(x, *params1)
+
+ # random values between [-1;1]
+ noise = 2 * numpy.random.random(5000) - 1
+ # make it +- 5%
+ noise *= 0.05
+
+ # 2 gaussians with very large fwhm, as background signal
+ actual_bg = sum_gauss(x, 15, 3500, 3000, 5, 1000, 1500)
+
+ # Add 5% random noise to gaussians and add background
+ y = y0 + numpy.average(y0) * noise + actual_bg
+
+ # Open widget
+ a = qt.QApplication(sys.argv)
+ a.lastWindowClosed.connect(a.quit)
+
+ def mySlot(ddict):
+ print(ddict)
+
+ w = BackgroundDialog()
+ w.parametersWidget.parametersWidget.sigBackgroundParamWidgetSignal.connect(mySlot)
+ w.setData(x, y)
+ w.exec_()
+ #a.exec_()
+
+if __name__ == "__main__":
+ main()
diff --git a/silx/gui/fit/FitConfig.py b/silx/gui/fit/FitConfig.py
new file mode 100644
index 0000000..70b6fbe
--- /dev/null
+++ b/silx/gui/fit/FitConfig.py
@@ -0,0 +1,540 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2004-2016 V.A. Sole, European Synchrotron Radiation Facility
+#
+# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
+# the ESRF by the Software group.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ######################################################################### */
+"""This module defines widgets used to build a fit configuration dialog.
+The resulting dialog widget outputs a dictionary of configuration parameters.
+"""
+from silx.gui import qt
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "30/11/2016"
+
+
+class TabsDialog(qt.QDialog):
+ """Dialog widget containing a QTabWidget :attr:`tabWidget`
+ and a buttons:
+
+ # - buttonHelp
+ - buttonDefaults
+ - buttonOk
+ - buttonCancel
+
+ This dialog defines a __len__ returning the number of tabs,
+ and an __iter__ method yielding the tab widgets.
+ """
+ def __init__(self, parent=None):
+ qt.QDialog.__init__(self, parent)
+ self.tabWidget = qt.QTabWidget(self)
+
+ layout = qt.QVBoxLayout(self)
+ layout.addWidget(self.tabWidget)
+
+ layout2 = qt.QHBoxLayout(None)
+
+ # self.buttonHelp = qt.QPushButton(self)
+ # self.buttonHelp.setText("Help")
+ # layout2.addWidget(self.buttonHelp)
+
+ self.buttonDefault = qt.QPushButton(self)
+ self.buttonDefault.setText("Default")
+ layout2.addWidget(self.buttonDefault)
+
+ spacer = qt.QSpacerItem(20, 20,
+ qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Minimum)
+ layout2.addItem(spacer)
+
+ self.buttonOk = qt.QPushButton(self)
+ self.buttonOk.setText("OK")
+ layout2.addWidget(self.buttonOk)
+
+ self.buttonCancel = qt.QPushButton(self)
+ self.buttonCancel.setText("Cancel")
+ layout2.addWidget(self.buttonCancel)
+
+ layout.addLayout(layout2)
+
+ self.buttonOk.clicked.connect(self.accept)
+ self.buttonCancel.clicked.connect(self.reject)
+
+ def __len__(self):
+ """Return number of tabs"""
+ return self.tabWidget.count()
+
+ def __iter__(self):
+ """Return the next tab widget in :attr:`tabWidget` every
+ time this method is called.
+
+ :return: Tab widget
+ :rtype: QWidget
+ """
+ for widget_index in range(len(self)):
+ yield self.tabWidget.widget(widget_index)
+
+ def addTab(self, page, label):
+ """Add a new tab
+
+ :param page: Content of new page. Must be a widget with
+ a get() method returning a dictionary.
+ :param str label: Tab label
+ """
+ self.tabWidget.addTab(page, label)
+
+ def getTabLabels(self):
+ """
+ Return a list of all tab labels in :attr:`tabWidget`
+ """
+ return [self.tabWidget.tabText(i) for i in range(len(self))]
+
+
+class TabsDialogData(TabsDialog):
+ """This dialog adds a data attribute to :class:`TabsDialog`.
+
+ Data input in widgets, such as text entries or checkboxes, is stored in an
+ attribute :attr:`output` when the user clicks the OK button.
+
+ A default dictionary can be supplied when this dialog is initialized, to
+ be used as default data for :attr:`output`.
+ """
+ def __init__(self, parent=None, modal=True, default=None):
+ """
+
+ :param parent: Parent :class:`QWidget`
+ :param modal: If `True`, dialog is modal, meaning this dialog remains
+ in front of it's parent window and disables it until the user is
+ done interacting with the dialog
+ :param default: Default dictionary, used to initialize and reset
+ :attr:`output`.
+ """
+ TabsDialog.__init__(self, parent)
+ self.setModal(modal)
+ self.setWindowTitle("Fit configuration")
+
+ self.output = {}
+
+ self.default = {} if default is None else default
+
+ self.buttonDefault.clicked.connect(self.setDefault)
+ # self.keyPressEvent(qt.Qt.Key_Enter).
+
+ def keyPressEvent(self, event):
+ """Redefining this method to ignore Enter key
+ (for some reason it activates buttonDefault callback which
+ resets all widgets)
+ """
+ if event.key() in [qt.Qt.Key_Enter, qt.Qt.Key_Return]:
+ return
+ TabsDialog.keyPressEvent(self, event)
+
+ def accept(self):
+ """When *OK* is clicked, update :attr:`output` with data from
+ various widgets
+ """
+ self.output.update(self.default)
+
+ # loop over all tab widgets (uses TabsDialog.__iter__)
+ for tabWidget in self:
+ self.output.update(tabWidget.get())
+
+ # avoid pathological None cases
+ for key in self.output.keys():
+ if self.output[key] is None:
+ if key in self.default:
+ self.output[key] = self.default[key]
+ super(TabsDialogData, self).accept()
+
+ def reject(self):
+ """When the *Cancel* button is clicked, reinitialize :attr:`output`
+ and quit
+ """
+ self.setDefault()
+ super(TabsDialogData, self).reject()
+
+ def setDefault(self, newdefault=None):
+ """Reinitialize :attr:`output` with :attr:`default` or with
+ new dictionary ``newdefault`` if provided.
+ Call :meth:`setDefault` for each tab widget, if available.
+ """
+ self.output = {}
+ if newdefault is None:
+ newdefault = self.default
+ else:
+ self.default = newdefault
+ self.output.update(newdefault)
+
+ for tabWidget in self:
+ if hasattr(tabWidget, "setDefault"):
+ tabWidget.setDefault(self.output)
+
+
+class ConstraintsPage(qt.QGroupBox):
+ """Checkable QGroupBox widget filled with QCheckBox widgets,
+ to configure the fit estimation for standard fit theories.
+ """
+ def __init__(self, parent=None, title="Set constraints"):
+ super(ConstraintsPage, self).__init__(parent)
+ self.setTitle(title)
+ self.setToolTip("Disable 'Set constraints' to remove all " +
+ "constraints on all fit parameters")
+ self.setCheckable(True)
+
+ layout = qt.QVBoxLayout(self)
+ self.setLayout(layout)
+
+ self.positiveHeightCB = qt.QCheckBox("Force positive height/area", self)
+ self.positiveHeightCB.setToolTip("Fit must find positive peaks")
+ layout.addWidget(self.positiveHeightCB)
+
+ self.positionInIntervalCB = qt.QCheckBox("Force position in interval", self)
+ self.positionInIntervalCB.setToolTip(
+ "Fit must position peak within X limits")
+ layout.addWidget(self.positionInIntervalCB)
+
+ self.positiveFwhmCB = qt.QCheckBox("Force positive FWHM", self)
+ self.positiveFwhmCB.setToolTip("Fit must find a positive FWHM")
+ layout.addWidget(self.positiveFwhmCB)
+
+ self.sameFwhmCB = qt.QCheckBox("Force same FWHM for all peaks", self)
+ self.sameFwhmCB.setToolTip("Fit must find same FWHM for all peaks")
+ layout.addWidget(self.sameFwhmCB)
+
+ self.quotedEtaCB = qt.QCheckBox("Force Eta between 0 and 1", self)
+ self.quotedEtaCB.setToolTip(
+ "Fit must find Eta between 0 and 1 for pseudo-Voigt function")
+ layout.addWidget(self.quotedEtaCB)
+
+ layout.addStretch()
+
+ self.setDefault()
+
+ def setDefault(self, default_dict=None):
+ """Set default state for all widgets.
+
+ :param default_dict: If a default config dictionary is provided as
+ a parameter, its values are used as default state."""
+ if default_dict is None:
+ default_dict = {}
+ # this one uses reverse logic: if checked, NoConstraintsFlag must be False
+ self.setChecked(
+ not default_dict.get('NoConstraintsFlag', False))
+ self.positiveHeightCB.setChecked(
+ default_dict.get('PositiveHeightAreaFlag', True))
+ self.positionInIntervalCB.setChecked(
+ default_dict.get('QuotedPositionFlag', False))
+ self.positiveFwhmCB.setChecked(
+ default_dict.get('PositiveFwhmFlag', True))
+ self.sameFwhmCB.setChecked(
+ default_dict.get('SameFwhmFlag', False))
+ self.quotedEtaCB.setChecked(
+ default_dict.get('QuotedEtaFlag', False))
+
+ def get(self):
+ """Return a dictionary of constraint flags, to be processed by the
+ :meth:`configure` method of the selected fit theory."""
+ ddict = {
+ 'NoConstraintsFlag': not self.isChecked(),
+ 'PositiveHeightAreaFlag': self.positiveHeightCB.isChecked(),
+ 'QuotedPositionFlag': self.positionInIntervalCB.isChecked(),
+ 'PositiveFwhmFlag': self.positiveFwhmCB.isChecked(),
+ 'SameFwhmFlag': self.sameFwhmCB.isChecked(),
+ 'QuotedEtaFlag': self.quotedEtaCB.isChecked(),
+ }
+ return ddict
+
+
+class SearchPage(qt.QWidget):
+ def __init__(self, parent=None):
+ super(SearchPage, self).__init__(parent)
+ layout = qt.QVBoxLayout(self)
+
+ self.manualFwhmGB = qt.QGroupBox("Define FWHM manually", self)
+ self.manualFwhmGB.setCheckable(True)
+ self.manualFwhmGB.setToolTip(
+ "If disabled, the FWHM parameter used for peak search is " +
+ "estimated based on the highest peak in the data")
+ layout.addWidget(self.manualFwhmGB)
+ # ------------ GroupBox fwhm--------------------------
+ layout2 = qt.QHBoxLayout(self.manualFwhmGB)
+ self.manualFwhmGB.setLayout(layout2)
+
+ label = qt.QLabel("Fwhm Points", self.manualFwhmGB)
+ layout2.addWidget(label)
+
+ self.fwhmPointsSpin = qt.QSpinBox(self.manualFwhmGB)
+ self.fwhmPointsSpin.setRange(0, 999999)
+ self.fwhmPointsSpin.setToolTip("Typical peak fwhm (number of data points)")
+ layout2.addWidget(self.fwhmPointsSpin)
+ # ----------------------------------------------------
+
+ self.manualScalingGB = qt.QGroupBox("Define scaling manually", self)
+ self.manualScalingGB.setCheckable(True)
+ self.manualScalingGB.setToolTip(
+ "If disabled, the Y scaling used for peak search is " +
+ "estimated automatically")
+ layout.addWidget(self.manualScalingGB)
+ # ------------ GroupBox scaling-----------------------
+ layout3 = qt.QHBoxLayout(self.manualScalingGB)
+ self.manualScalingGB.setLayout(layout3)
+
+ label = qt.QLabel("Y Scaling", self.manualScalingGB)
+ layout3.addWidget(label)
+
+ self.yScalingEntry = qt.QLineEdit(self.manualScalingGB)
+ self.yScalingEntry.setToolTip(
+ "Data values will be multiplied by this value prior to peak" +
+ " search")
+ self.yScalingEntry.setValidator(qt.QDoubleValidator())
+ layout3.addWidget(self.yScalingEntry)
+ # ----------------------------------------------------
+
+ # ------------------- grid layout --------------------
+ containerWidget = qt.QWidget(self)
+ layout4 = qt.QHBoxLayout(containerWidget)
+ containerWidget.setLayout(layout4)
+
+ label = qt.QLabel("Sensitivity", containerWidget)
+ layout4.addWidget(label)
+
+ self.sensitivityEntry = qt.QLineEdit(containerWidget)
+ self.sensitivityEntry.setToolTip(
+ "Peak search sensitivity threshold, expressed as a multiple " +
+ "of the standard deviation of the noise.\nMinimum value is 1 " +
+ "(to be detected, peak must be higher than the estimated noise)")
+ sensivalidator = qt.QDoubleValidator()
+ sensivalidator.setBottom(1.0)
+ self.sensitivityEntry.setValidator(sensivalidator)
+ layout4.addWidget(self.sensitivityEntry)
+ # ----------------------------------------------------
+ layout.addWidget(containerWidget)
+
+ self.forcePeakPresenceCB = qt.QCheckBox("Force peak presence", self)
+ self.forcePeakPresenceCB.setToolTip(
+ "If peak search algorithm is unsuccessful, place one peak " +
+ "at the maximum of the curve")
+ layout.addWidget(self.forcePeakPresenceCB)
+
+ layout.addStretch()
+
+ self.setDefault()
+
+ def setDefault(self, default_dict=None):
+ """Set default values for all widgets.
+
+ :param default_dict: If a default config dictionary is provided as
+ a parameter, its values are used as default values."""
+ if default_dict is None:
+ default_dict = {}
+ self.manualFwhmGB.setChecked(
+ not default_dict.get('AutoFwhm', True))
+ self.fwhmPointsSpin.setValue(
+ default_dict.get('FwhmPoints', 8))
+ self.sensitivityEntry.setText(
+ str(default_dict.get('Sensitivity', 1.0)))
+ self.manualScalingGB.setChecked(
+ not default_dict.get('AutoScaling', False))
+ self.yScalingEntry.setText(
+ str(default_dict.get('Yscaling', 1.0)))
+ self.forcePeakPresenceCB.setChecked(
+ default_dict.get('ForcePeakPresence', False))
+
+ def get(self):
+ """Return a dictionary of peak search parameters, to be processed by
+ the :meth:`configure` method of the selected fit theory."""
+ ddict = {
+ 'AutoFwhm': not self.manualFwhmGB.isChecked(),
+ 'FwhmPoints': self.fwhmPointsSpin.value(),
+ 'Sensitivity': safe_float(self.sensitivityEntry.text()),
+ 'AutoScaling': not self.manualScalingGB.isChecked(),
+ 'Yscaling': safe_float(self.yScalingEntry.text()),
+ 'ForcePeakPresence': self.forcePeakPresenceCB.isChecked()
+ }
+ return ddict
+
+
+class BackgroundPage(qt.QGroupBox):
+ """Background subtraction configuration, specific to fittheories
+ estimation functions."""
+ def __init__(self, parent=None,
+ title="Subtract strip background prior to estimation"):
+ super(BackgroundPage, self).__init__(parent)
+ self.setTitle(title)
+ self.setCheckable(True)
+ self.setToolTip(
+ "The strip algorithm strips away peaks to compute the " +
+ "background signal.\nAt each iteration, a sample is compared " +
+ "to the average of the two samples at a given distance in both" +
+ " directions,\n and if its value is higher than the average,"
+ "it is replaced by the average.")
+
+ layout = qt.QGridLayout(self)
+ self.setLayout(layout)
+
+ for i, label_text in enumerate(
+ ["Strip width (in samples)",
+ "Number of iterations",
+ "Strip threshold factor"]):
+ label = qt.QLabel(label_text)
+ layout.addWidget(label, i, 0)
+
+ self.stripWidthSpin = qt.QSpinBox(self)
+ self.stripWidthSpin.setToolTip(
+ "Width, in number of samples, of the strip operator")
+ self.stripWidthSpin.setRange(1, 999999)
+
+ layout.addWidget(self.stripWidthSpin, 0, 1)
+
+ self.numIterationsSpin = qt.QSpinBox(self)
+ self.numIterationsSpin.setToolTip(
+ "Number of iterations of the strip algorithm")
+ self.numIterationsSpin.setRange(1, 999999)
+ layout.addWidget(self.numIterationsSpin, 1, 1)
+
+ self.thresholdFactorEntry = qt.QLineEdit(self)
+ self.thresholdFactorEntry.setToolTip(
+ "Factor used by the strip algorithm to decide whether a sample" +
+ "value should be stripped.\nThe value must be higher than the " +
+ "average of the 2 samples at +- w times this factor.\n")
+ self.thresholdFactorEntry.setValidator(qt.QDoubleValidator())
+ layout.addWidget(self.thresholdFactorEntry, 2, 1)
+
+ self.smoothStripGB = qt.QGroupBox("Apply smoothing prior to strip", self)
+ self.smoothStripGB.setCheckable(True)
+ self.smoothStripGB.setToolTip(
+ "Apply a smoothing before subtracting strip background" +
+ " in fit and estimate processes")
+ smoothlayout = qt.QHBoxLayout(self.smoothStripGB)
+ label = qt.QLabel("Smoothing width (Savitsky-Golay)")
+ smoothlayout.addWidget(label)
+ self.smoothingWidthSpin = qt.QSpinBox(self)
+ self.smoothingWidthSpin.setToolTip(
+ "Width parameter for Savitsky-Golay smoothing (number of samples, must be odd)")
+ self.smoothingWidthSpin.setRange(3, 101)
+ self.smoothingWidthSpin.setSingleStep(2)
+ smoothlayout.addWidget(self.smoothingWidthSpin)
+
+ layout.addWidget(self.smoothStripGB, 3, 0, 1, 2)
+
+ layout.setRowStretch(4, 1)
+
+ self.setDefault()
+
+ def setDefault(self, default_dict=None):
+ """Set default values for all widgets.
+
+ :param default_dict: If a default config dictionary is provided as
+ a parameter, its values are used as default values."""
+ if default_dict is None:
+ default_dict = {}
+
+ self.setChecked(
+ default_dict.get('StripBackgroundFlag', True))
+
+ self.stripWidthSpin.setValue(
+ default_dict.get('StripWidth', 2))
+ self.numIterationsSpin.setValue(
+ default_dict.get('StripIterations', 5000))
+ self.thresholdFactorEntry.setText(
+ str(default_dict.get('StripThreshold', 1.0)))
+ self.smoothStripGB.setChecked(
+ default_dict.get('SmoothingFlag', False))
+ self.smoothingWidthSpin.setValue(
+ default_dict.get('SmoothingWidth', 3))
+
+ def get(self):
+ """Return a dictionary of background subtraction parameters, to be
+ processed by the :meth:`configure` method of the selected fit theory.
+ """
+ ddict = {
+ 'StripBackgroundFlag': self.isChecked(),
+ 'StripWidth': self.stripWidthSpin.value(),
+ 'StripIterations': self.numIterationsSpin.value(),
+ 'StripThreshold': safe_float(self.thresholdFactorEntry.text()),
+ 'SmoothingFlag': self.smoothStripGB.isChecked(),
+ 'SmoothingWidth': self.smoothingWidthSpin.value()
+ }
+ return ddict
+
+
+def safe_float(string_, default=1.0):
+ """Convert a string into a float.
+ If the conversion fails, return the default value.
+ """
+ try:
+ ret = float(string_)
+ except ValueError:
+ return default
+ else:
+ return ret
+
+
+def safe_int(string_, default=1):
+ """Convert a string into a integer.
+ If the conversion fails, return the default value.
+ """
+ try:
+ ret = int(float(string_))
+ except ValueError:
+ return default
+ else:
+ return ret
+
+
+def getFitConfigDialog(parent=None, default=None, modal=True):
+ """Instantiate and return a fit configuration dialog, adapted
+ for configuring standard fit theories from
+ :mod:`silx.math.fit.fittheories`.
+
+ :return: Instance of :class:`TabsDialogData` with 3 tabs:
+ :class:`ConstraintsPage`, :class:`SearchPage` and
+ :class:`BackgroundPage`
+ """
+ tdd = TabsDialogData(parent=parent, default=default)
+ tdd.addTab(ConstraintsPage(), label="Constraints")
+ tdd.addTab(SearchPage(), label="Peak search")
+ tdd.addTab(BackgroundPage(), label="Background")
+ # apply default to newly added pages
+ tdd.setDefault()
+
+ return tdd
+
+
+def main():
+ a = qt.QApplication([])
+
+ mw = qt.QMainWindow()
+ mw.show()
+
+ tdd = getFitConfigDialog(mw, default={"a": 1})
+ tdd.show()
+ tdd.exec_()
+ print("TabsDialogData result: ", tdd.result())
+ print("TabsDialogData output: ", tdd.output)
+
+ a.exec_()
+
+if __name__ == "__main__":
+ main()
diff --git a/silx/gui/fit/FitWidget.py b/silx/gui/fit/FitWidget.py
new file mode 100644
index 0000000..a5c3cfd
--- /dev/null
+++ b/silx/gui/fit/FitWidget.py
@@ -0,0 +1,727 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
+# the ESRF by the Software group.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ######################################################################### */
+"""This module provides a widget designed to configure and run a fitting
+process with constraints on parameters.
+
+The main class is :class:`FitWidget`. It relies on
+:mod:`silx.math.fit.fitmanager`, which relies on :func:`silx.math.fit.leastsq`.
+
+The user can choose between functions before running the fit. These function can
+be user defined, or by default are loaded from
+:mod:`silx.math.fit.fittheories`.
+"""
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "15/02/2017"
+
+import logging
+import sys
+import traceback
+import warnings
+
+from silx.math.fit import fittheories
+from silx.math.fit import fitmanager, functions
+from silx.gui import qt
+from .FitWidgets import (FitActionsButtons, FitStatusLines,
+ FitConfigWidget, ParametersTab)
+from .FitConfig import getFitConfigDialog
+from .BackgroundWidget import getBgDialog, BackgroundDialog
+
+QTVERSION = qt.qVersion()
+DEBUG = 0
+_logger = logging.getLogger(__name__)
+
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "30/11/2016"
+
+
+class FitWidget(qt.QWidget):
+ """This widget can be used to configure, run and display results of a
+ fitting process.
+
+ The standard steps for using this widget is to initialize it, then load
+ the data to be fitted.
+
+ Optionally, you can also load user defined fit theories. If you skip this
+ step, a series of default fit functions will be presented (gaussian-like
+ functions), and you can later load your custom fit theories from an
+ external file using the GUI.
+
+ A fit theory is a fit function and its associated features:
+
+ - estimation function,
+ - list of parameter names
+ - numerical derivative algorithm
+ - configuration widget
+
+ Once the widget is up and running, the user may select a fit theory and a
+ background theory, change configuration parameters specific to the theory
+ run the estimation, set constraints on parameters and run the actual fit.
+
+ The results are displayed in a table.
+ """
+ sigFitWidgetSignal = qt.Signal(object)
+ """This signal is emitted by the estimation and fit methods.
+ It carries a dictionary with two items:
+
+ - *event*: one of the following strings
+
+ - *EstimateStarted*,
+ - *FitStarted*
+ - *EstimateFinished*,
+ - *FitFinished*
+ - *EstimateFailed*
+ - *FitFailed*
+
+ - *data*: None, or fit/estimate results (see documentation for
+ :attr:`silx.math.fit.fitmanager.FitManager.fit_results`)
+ """
+
+ def __init__(self, parent=None, title=None, fitmngr=None,
+ enableconfig=True, enablestatus=True, enablebuttons=True):
+ """
+
+ :param parent: Parent widget
+ :param title: Window title
+ :param fitmngr: User defined instance of
+ :class:`silx.math.fit.fitmanager.FitManager`, or ``None``
+ :param enableconfig: If ``True``, activate widgets to modify the fit
+ configuration (select between several fit functions or background
+ functions, apply global constraints, peak search parameters…)
+ :param enablestatus: If ``True``, add a fit status widget, to display
+ a message when fit estimation is available and when fit results
+ are available, as well as a measure of the fit error.
+ :param enablebuttons: If ``True``, add buttons to run estimation and
+ fitting.
+ """
+ if title is None:
+ title = "FitWidget"
+ qt.QWidget.__init__(self, parent)
+
+ self.setWindowTitle(title)
+ layout = qt.QVBoxLayout(self)
+
+ self.fitmanager = self._setFitManager(fitmngr)
+ """Instance of :class:`FitManager`.
+ This is the underlying data model of this FitWidget.
+
+ If no custom theories are defined, the default ones from
+ :mod:`silx.math.fit.fittheories` are imported.
+ """
+
+ # reference fitmanager.configure method for direct access
+ self.configure = self.fitmanager.configure
+ self.fitconfig = self.fitmanager.fitconfig
+
+ self.configdialogs = {}
+ """This dictionary defines the fit configuration widgets
+ associated with the fit theories in :attr:`fitmanager.theories`
+
+ Keys must correspond to existing theory names, i.e. existing keys
+ in :attr:`fitmanager.theories`.
+
+ Values must be instances of QDialog widgets with an additional
+ *output* attribute, a dictionary storing configuration parameters
+ interpreted by the corresponding fit theory.
+
+ The dialog can also define a *setDefault* method to initialize the
+ widget values with values in a dictionary passed as a parameter.
+ This will be executed first.
+
+ In case the widget does not actually inherit :class:`QDialog`, it
+ must at least implement the following methods (executed in this
+ particular order):
+
+ - :meth:`show`: should cause the widget to become visible to the
+ user)
+ - :meth:`exec_`: should run while the user is interacting with the
+ widget, interrupting the rest of the program. It should
+ typically end (*return*) when the user clicks an *OK*
+ or a *Cancel* button.
+ - :meth:`result`: must return ``True`` if the new configuration in
+ attribute :attr:`output` is to be accepted (user clicked *OK*),
+ or return ``False`` if :attr:`output` is to be rejected (user
+ clicked *Cancel*)
+
+ To associate a custom configuration widget with a fit theory, use
+ :meth:`associateConfigDialog`. E.g.::
+
+ fw = FitWidget()
+ my_config_widget = MyGaussianConfigWidget(parent=fw)
+ fw.associateConfigDialog(theory_name="Gaussians",
+ config_widget=my_config_widget)
+ """
+
+ self.bgconfigdialogs = {}
+ """Same as :attr:`configdialogs`, except that the widget is associated
+ with a background theory in :attr:`fitmanager.bgtheories`"""
+
+ self._associateDefaultConfigDialogs()
+
+ self.guiConfig = None
+ """Configuration widget at the top of FitWidget, to select
+ fit function, background function, and open an advanced
+ configuration dialog."""
+
+ self.guiParameters = ParametersTab(self)
+ """Table widget for display of fit parameters and constraints"""
+
+ if enableconfig:
+ self.guiConfig = FitConfigWidget(self)
+ """Function selector and configuration widget"""
+
+ self.guiConfig.FunConfigureButton.clicked.connect(
+ self.__funConfigureGuiSlot)
+ self.guiConfig.BgConfigureButton.clicked.connect(
+ self.__bgConfigureGuiSlot)
+
+ self.guiConfig.WeightCheckBox.setChecked(
+ self.fitconfig.get("WeightFlag", False))
+ self.guiConfig.WeightCheckBox.stateChanged[int].connect(self.weightEvent)
+
+ self.guiConfig.BkgComBox.activated[str].connect(self.bkgEvent)
+ self.guiConfig.FunComBox.activated[str].connect(self.funEvent)
+ self._populateFunctions()
+
+ layout.addWidget(self.guiConfig)
+
+ layout.addWidget(self.guiParameters)
+
+ if enablestatus:
+ self.guistatus = FitStatusLines(self)
+ """Status bar"""
+ layout.addWidget(self.guistatus)
+
+ if enablebuttons:
+ self.guibuttons = FitActionsButtons(self)
+ """Widget with estimate, start fit and dismiss buttons"""
+ self.guibuttons.EstimateButton.clicked.connect(self.estimate)
+ self.guibuttons.StartFitButton.clicked.connect(self.startFit)
+ self.guibuttons.DismissButton.clicked.connect(self.dismiss)
+ layout.addWidget(self.guibuttons)
+
+ def _setFitManager(self, fitinstance):
+ """Initialize a :class:`FitManager` instance, to be assigned to
+ :attr:`fitmanager`, or use a custom FitManager instance.
+
+ :param fitinstance: Existing instance of FitManager, possibly
+ customized by the user, or None to load a default instance."""
+ if isinstance(fitinstance, fitmanager.FitManager):
+ # customized
+ fitmngr = fitinstance
+ else:
+ # initialize default instance
+ fitmngr = fitmanager.FitManager()
+
+ # initialize the default fitting functions in case
+ # none is present
+ if not len(fitmngr.theories):
+ fitmngr.loadtheories(fittheories)
+
+ return fitmngr
+
+ def _associateDefaultConfigDialogs(self):
+ """Fill :attr:`bgconfigdialogs` and :attr:`configdialogs` by calling
+ :meth:`associateConfigDialog` with default config dialog widgets.
+ """
+ # associate silx.gui.fit.FitConfig with all theories
+ # Users can later associate their own custom dialogs to
+ # replace the default.
+ configdialog = getFitConfigDialog(parent=self,
+ default=self.fitconfig)
+ for theory in self.fitmanager.theories:
+ self.associateConfigDialog(theory, configdialog)
+ for bgtheory in self.fitmanager.bgtheories:
+ self.associateConfigDialog(bgtheory, configdialog,
+ theory_is_background=True)
+
+ # associate silx.gui.fit.BackgroundWidget with Strip and Snip
+ bgdialog = getBgDialog(parent=self,
+ default=self.fitconfig)
+ for bgtheory in ["Strip", "Snip"]:
+ if bgtheory in self.fitmanager.bgtheories:
+ self.associateConfigDialog(bgtheory, bgdialog,
+ theory_is_background=True)
+
+ def _populateFunctions(self):
+ """Fill combo-boxes with fit theories and background theories
+ loaded by :attr:`fitmanager`.
+ Run :meth:`fitmanager.configure` to ensure the custom configuration
+ of the selected theory has been loaded into :attr:`fitconfig`"""
+ for theory_name in self.fitmanager.bgtheories:
+ self.guiConfig.BkgComBox.addItem(theory_name)
+ self.guiConfig.BkgComBox.setItemData(
+ self.guiConfig.BkgComBox.findText(theory_name),
+ self.fitmanager.bgtheories[theory_name].description,
+ qt.Qt.ToolTipRole)
+
+ for theory_name in self.fitmanager.theories:
+ self.guiConfig.FunComBox.addItem(theory_name)
+ self.guiConfig.FunComBox.setItemData(
+ self.guiConfig.FunComBox.findText(theory_name),
+ self.fitmanager.theories[theory_name].description,
+ qt.Qt.ToolTipRole)
+
+ # - activate selected fit theory (if any)
+ # - activate selected bg theory (if any)
+ configuration = self.fitmanager.configure()
+ if self.fitmanager.selectedtheory is None:
+ # take the first one by default
+ self.guiConfig.FunComBox.setCurrentIndex(1)
+ self.funEvent(list(self.fitmanager.theories.keys())[0])
+ else:
+ idx = list(self.fitmanager.theories).index(self.fitmanager.selectedtheory)
+ self.guiConfig.FunComBox.setCurrentIndex(idx + 1)
+ self.funEvent(self.fitmanager.selectedtheory)
+
+ if self.fitmanager.selectedbg is None:
+ self.guiConfig.BkgComBox.setCurrentIndex(1)
+ self.bkgEvent(list(self.fitmanager.bgtheories.keys())[0])
+ else:
+ idx = list(self.fitmanager.bgtheories).index(self.fitmanager.selectedbg)
+ self.guiConfig.BkgComBox.setCurrentIndex(idx + 1)
+ self.bkgEvent(self.fitmanager.selectedbg)
+
+ configuration.update(self.configure())
+
+ def setdata(self, x, y, sigmay=None, xmin=None, xmax=None):
+ warnings.warn("Method renamed to setData",
+ DeprecationWarning)
+ self.setData(x, y, sigmay, xmin, xmax)
+
+ def setData(self, x, y, sigmay=None, xmin=None, xmax=None):
+ """Set data to be fitted.
+
+ :param x: Abscissa data. If ``None``, :attr:`xdata`` is set to
+ ``numpy.array([0.0, 1.0, 2.0, ..., len(y)-1])``
+ :type x: Sequence or numpy array or None
+ :param y: The dependant data ``y = f(x)``. ``y`` must have the same
+ shape as ``x`` if ``x`` is not ``None``.
+ :type y: Sequence or numpy array or None
+ :param sigmay: The uncertainties in the ``ydata`` array. These are
+ used as weights in the least-squares problem.
+ If ``None``, the uncertainties are assumed to be 1.
+ :type sigmay: Sequence or numpy array or None
+ :param xmin: Lower value of x values to use for fitting
+ :param xmax: Upper value of x values to use for fitting
+ """
+ self.fitmanager.setdata(x=x, y=y, sigmay=sigmay,
+ xmin=xmin, xmax=xmax)
+ for config_dialog in self.bgconfigdialogs.values():
+ if isinstance(config_dialog, BackgroundDialog):
+ config_dialog.setData(x, y, xmin=xmin, xmax=xmax)
+
+ def associateConfigDialog(self, theory_name, config_widget,
+ theory_is_background=False):
+ """Associate an instance of custom configuration dialog widget to
+ a fit theory or to a background theory.
+
+ This adds or modifies an item in the correspondence table
+ :attr:`configdialogs` or :attr:`bgconfigdialogs`.
+
+ :param str theory_name: Name of fit theory. This must be a key of dict
+ :attr:`fitmanager.theories`
+ :param config_widget: Custom configuration widget. See documentation
+ for :attr:`configdialogs`
+ :param bool theory_is_background: If flag is *True*, add dialog to
+ :attr:`bgconfigdialogs` rather than :attr:`configdialogs`
+ (default).
+ :raise: KeyError if parameter ``theory_name`` does not match an
+ existing fit theory or background theory in :attr:`fitmanager`.
+ :raise: AttributeError if the widget does not implement the mandatory
+ methods (*show*, *exec_*, *result*, *setDefault*) or the mandatory
+ attribute (*output*).
+ """
+ theories = self.fitmanager.bgtheories if theory_is_background else\
+ self.fitmanager.theories
+
+ if theory_name not in theories:
+ raise KeyError("%s does not match an existing fitmanager theory")
+
+ if config_widget is not None:
+ for mandatory_attr in ["show", "exec_", "result", "output"]:
+ if not hasattr(config_widget, mandatory_attr):
+ raise AttributeError(
+ "Custom configuration widget must define " +
+ "attribute or method " + mandatory_attr)
+
+ if theory_is_background:
+ self.bgconfigdialogs[theory_name] = config_widget
+ else:
+ self.configdialogs[theory_name] = config_widget
+
+ def _emitSignal(self, ddict):
+ """Emit pyqtSignal after estimation completed
+ (``ddict = {'event': 'EstimateFinished', 'data': fit_results}``)
+ and after fit completed
+ (``ddict = {'event': 'FitFinished', 'data': fit_results}``)"""
+ self.sigFitWidgetSignal.emit(ddict)
+
+ def __funConfigureGuiSlot(self):
+ """Open an advanced configuration dialog widget"""
+ self.__configureGui(dialog_type="function")
+
+ def __bgConfigureGuiSlot(self):
+ """Open an advanced configuration dialog widget"""
+ self.__configureGui(dialog_type="background")
+
+ def __configureGui(self, newconfiguration=None, dialog_type="function"):
+ """Open an advanced configuration dialog widget to get a configuration
+ dictionary, or use a supplied configuration dictionary. Call
+ :meth:`configure` with this dictionary as a parameter. Update the gui
+ accordingly. Reinitialize the fit results in the table and in
+ :attr:`fitmanager`.
+
+ :param newconfiguration: User supplied configuration dictionary. If ``None``,
+ open a dialog widget that returns a dictionary."""
+ configuration = self.configure()
+ # get new dictionary
+ if newconfiguration is None:
+ newconfiguration = self.configureDialog(configuration, dialog_type)
+ # update configuration
+ configuration.update(self.configure(**newconfiguration))
+ # set fit function theory
+ try:
+ i = 1 + \
+ list(self.fitmanager.theories.keys()).index(
+ self.fitmanager.selectedtheory)
+ self.guiConfig.FunComBox.setCurrentIndex(i)
+ self.funEvent(self.fitmanager.selectedtheory)
+ except ValueError:
+ _logger.error("Function not in list %s",
+ self.fitmanager.selectedtheory)
+ self.funEvent(list(self.fitmanager.theories.keys())[0])
+ # current background
+ try:
+ i = 1 + \
+ list(self.fitmanager.bgtheories.keys()).index(
+ self.fitmanager.selectedbg)
+ self.guiConfig.BkgComBox.setCurrentIndex(i)
+ self.bkgEvent(self.fitmanager.selectedbg)
+ except ValueError:
+ _logger.error("Background not in list %s",
+ self.fitmanager.selectedbg)
+ self.bkgEvent(list(self.fitmanager.bgtheories.keys())[0])
+
+ # update the Gui
+ self.__initialParameters()
+
+ def configureDialog(self, oldconfiguration, dialog_type="function"):
+ """Display a dialog, allowing the user to define fit configuration
+ parameters.
+
+ By default, a common dialog is used for all fit theories. But if the
+ defined a custom dialog using :meth:`associateConfigDialog`, it is
+ used instead.
+
+ :param dict oldconfiguration: Dictionary containing previous configuration
+ :param str dialog_type: "function" or "background"
+ :return: User defined parameters in a dictionary
+ """
+ newconfiguration = {}
+ newconfiguration.update(oldconfiguration)
+
+ if dialog_type == "function":
+ theory = self.fitmanager.selectedtheory
+ configdialog = self.configdialogs[theory]
+ elif dialog_type == "background":
+ theory = self.fitmanager.selectedbg
+ configdialog = self.bgconfigdialogs[theory]
+
+ # this should only happen if a user specifically associates None
+ # with a theory, to have no configuration option
+ if configdialog is None:
+ return {}
+
+ # update state of configdialog before showing it
+ if hasattr(configdialog, "setDefault"):
+ configdialog.setDefault(newconfiguration)
+ configdialog.show()
+ configdialog.exec_()
+ if configdialog.result():
+ newconfiguration.update(configdialog.output)
+
+ return newconfiguration
+
+ def estimate(self):
+ """Run parameter estimation function then emit
+ :attr:`sigFitWidgetSignal` with a dictionary containing a status
+ message and a list of fit parameters estimations
+ in the format defined in
+ :attr:`silx.math.fit.fitmanager.FitManager.fit_results`
+
+ The emitted dictionary has an *"event"* key that can have
+ following values:
+
+ - *'EstimateStarted'*
+ - *'EstimateFailed'*
+ - *'EstimateFinished'*
+ """
+ try:
+ theory_name = self.fitmanager.selectedtheory
+ estimation_function = self.fitmanager.theories[theory_name].estimate
+ if estimation_function is not None:
+ ddict = {'event': 'EstimateStarted',
+ 'data': None}
+ self._emitSignal(ddict)
+ self.fitmanager.estimate(callback=self.fitStatus)
+ else:
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Information)
+ text = "Function does not define a way to estimate\n"
+ text += "the initial parameters. Please, fill them\n"
+ text += "yourself in the table and press Start Fit\n"
+ msg.setText(text)
+ msg.setWindowTitle('FitWidget Message')
+ msg.exec_()
+ return
+ except: # noqa (we want to catch and report all errors)
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Error on estimate: %s" % traceback.format_exc())
+ msg.exec_()
+ ddict = {
+ 'event': 'EstimateFailed',
+ 'data': None}
+ self._emitSignal(ddict)
+ return
+
+ self.guiParameters.fillFromFit(
+ self.fitmanager.fit_results, view='Fit')
+ self.guiParameters.removeAllViews(keep='Fit')
+ ddict = {
+ 'event': 'EstimateFinished',
+ 'data': self.fitmanager.fit_results}
+ self._emitSignal(ddict)
+
+ def startfit(self):
+ warnings.warn("Method renamed to startFit",
+ DeprecationWarning)
+ self.startFit()
+
+ def startFit(self):
+ """Run fit, then emit :attr:`sigFitWidgetSignal` with a dictionary
+ containing a status message and a list of fit
+ parameters results in the format defined in
+ :attr:`silx.math.fit.fitmanager.FitManager.fit_results`
+
+ The emitted dictionary has an *"event"* key that can have
+ following values:
+
+ - *'FitStarted'*
+ - *'FitFailed'*
+ - *'FitFinished'*
+ """
+ self.fitmanager.fit_results = self.guiParameters.getFitResults()
+ try:
+ ddict = {'event': 'FitStarted',
+ 'data': None}
+ self._emitSignal(ddict)
+ self.fitmanager.runfit(callback=self.fitStatus)
+ except: # noqa (we want to catch and report all errors)
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Error on Fit: %s" % traceback.format_exc())
+ msg.exec_()
+ ddict = {
+ 'event': 'FitFailed',
+ 'data': None
+ }
+ self._emitSignal(ddict)
+ return
+
+ self.guiParameters.fillFromFit(
+ self.fitmanager.fit_results, view='Fit')
+ self.guiParameters.removeAllViews(keep='Fit')
+ ddict = {
+ 'event': 'FitFinished',
+ 'data': self.fitmanager.fit_results
+ }
+ self._emitSignal(ddict)
+ return
+
+ def bkgEvent(self, bgtheory):
+ """Select background theory, then reinitialize parameters"""
+ bgtheory = str(bgtheory)
+ if bgtheory in self.fitmanager.bgtheories:
+ self.fitmanager.setbackground(bgtheory)
+ else:
+ functionsfile = qt.QFileDialog.getOpenFileName(
+ self, "Select python module with your function(s)", "",
+ "Python Files (*.py);;All Files (*)")
+
+ if len(functionsfile):
+ try:
+ self.fitmanager.loadbgtheories(functionsfile)
+ except ImportError:
+ qt.QMessageBox.critical(self, "ERROR",
+ "Function not imported")
+ return
+ else:
+ # empty the ComboBox
+ while self.guiConfig.BkgComBox.count() > 1:
+ self.guiConfig.BkgComBox.removeItem(1)
+ # and fill it again
+ for key in self.fitmanager.bgtheories:
+ self.guiConfig.BkgComBox.addItem(str(key))
+
+ i = 1 + \
+ list(self.fitmanager.bgtheories.keys()).index(
+ self.fitmanager.selectedbg)
+ self.guiConfig.BkgComBox.setCurrentIndex(i)
+ self.__initialParameters()
+
+ def funEvent(self, theoryname):
+ """Select a fit theory to be used for fitting. If this theory exists
+ in :attr:`fitmanager`, use it. Then, reinitialize table.
+
+ :param theoryname: Name of the fit theory to use for fitting. If this theory
+ exists in :attr:`fitmanager`, use it. Else, open a file dialog to open
+ a custom fit function definition file with
+ :meth:`fitmanager.loadtheories`.
+ """
+ theoryname = str(theoryname)
+ if theoryname in self.fitmanager.theories:
+ self.fitmanager.settheory(theoryname)
+ else:
+ # open a load file dialog
+ functionsfile = qt.QFileDialog.getOpenFileName(
+ self, "Select python module with your function(s)", "",
+ "Python Files (*.py);;All Files (*)")
+
+ if len(functionsfile):
+ try:
+ self.fitmanager.loadtheories(functionsfile)
+ except ImportError:
+ qt.QMessageBox.critical(self, "ERROR",
+ "Function not imported")
+ return
+ else:
+ # empty the ComboBox
+ while self.guiConfig.FunComBox.count() > 1:
+ self.guiConfig.FunComBox.removeItem(1)
+ # and fill it again
+ for key in self.fitmanager.theories:
+ self.guiConfig.FunComBox.addItem(str(key))
+
+ i = 1 + \
+ list(self.fitmanager.theories.keys()).index(
+ self.fitmanager.selectedtheory)
+ self.guiConfig.FunComBox.setCurrentIndex(i)
+ self.__initialParameters()
+
+ def weightEvent(self, flag):
+ """This is called when WeightCheckBox is clicked, to configure the
+ *WeightFlag* field in :attr:`fitmanager.fitconfig` and set weights
+ in the least-square problem."""
+ self.configure(WeightFlag=flag)
+ if flag:
+ self.fitmanager.enableweight()
+ else:
+ # set weights back to 1
+ self.fitmanager.disableweight()
+
+ def __initialParameters(self):
+ """Fill the fit parameters names with names of the parameters of
+ the selected background theory and the selected fit theory.
+ Initialize :attr:`fitmanager.fit_results` with these names, and
+ initialize the table with them. This creates a view called "Fit"
+ in :attr:`guiParameters`"""
+ self.fitmanager.parameter_names = []
+ self.fitmanager.fit_results = []
+ for pname in self.fitmanager.bgtheories[self.fitmanager.selectedbg].parameters:
+ self.fitmanager.parameter_names.append(pname)
+ self.fitmanager.fit_results.append({'name': pname,
+ 'estimation': 0,
+ 'group': 0,
+ 'code': 'FREE',
+ 'cons1': 0,
+ 'cons2': 0,
+ 'fitresult': 0.0,
+ 'sigma': 0.0,
+ 'xmin': None,
+ 'xmax': None})
+ if self.fitmanager.selectedtheory is not None:
+ theory = self.fitmanager.selectedtheory
+ for pname in self.fitmanager.theories[theory].parameters:
+ self.fitmanager.parameter_names.append(pname + "1")
+ self.fitmanager.fit_results.append({'name': pname + "1",
+ 'estimation': 0,
+ 'group': 1,
+ 'code': 'FREE',
+ 'cons1': 0,
+ 'cons2': 0,
+ 'fitresult': 0.0,
+ 'sigma': 0.0,
+ 'xmin': None,
+ 'xmax': None})
+
+ self.guiParameters.fillFromFit(
+ self.fitmanager.fit_results, view='Fit')
+
+ def fitStatus(self, data):
+ """Set *status* and *chisq* in status bar"""
+ if 'chisq' in data:
+ if data['chisq'] is None:
+ self.guistatus.ChisqLine.setText(" ")
+ else:
+ chisq = data['chisq']
+ self.guistatus.ChisqLine.setText("%6.2f" % chisq)
+
+ if 'status' in data:
+ status = data['status']
+ self.guistatus.StatusLine.setText(str(status))
+
+ def dismiss(self):
+ """Close FitWidget"""
+ self.close()
+
+
+if __name__ == "__main__":
+ import numpy
+
+ x = numpy.arange(1500).astype(numpy.float)
+ constant_bg = 3.14
+
+ p = [1000, 100., 30.0,
+ 500, 300., 25.,
+ 1700, 500., 35.,
+ 750, 700., 30.0,
+ 1234, 900., 29.5,
+ 302, 1100., 30.5,
+ 75, 1300., 21.]
+ y = functions.sum_gauss(x, *p) + constant_bg
+
+ a = qt.QApplication(sys.argv)
+ w = FitWidget()
+ w.setData(x=x, y=y)
+ w.show()
+ a.exec_()
diff --git a/silx/gui/fit/FitWidgets.py b/silx/gui/fit/FitWidgets.py
new file mode 100644
index 0000000..408666b
--- /dev/null
+++ b/silx/gui/fit/FitWidgets.py
@@ -0,0 +1,559 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2004-2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ######################################################################### */
+"""Collection of widgets used to build
+:class:`silx.gui.fit.FitWidget.FitWidget`"""
+
+from collections import OrderedDict
+
+from silx.gui import qt
+from silx.gui.fit.Parameters import Parameters
+
+QTVERSION = qt.qVersion()
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "13/10/2016"
+
+
+class FitActionsButtons(qt.QWidget):
+ """Widget with 3 ``QPushButton``:
+
+ The buttons can be accessed as public attributes::
+
+ - ``EstimateButton``
+ - ``StartFitButton``
+ - ``DismissButton``
+
+ You will typically need to access these attributes to connect the buttons
+ to actions. For instance, if you have 3 functions ``estimate``,
+ ``runfit`` and ``dismiss``, you can connect them like this::
+
+ >>> fit_actions_buttons = FitActionsButtons()
+ >>> fit_actions_buttons.EstimateButton.clicked.connect(estimate)
+ >>> fit_actions_buttons.StartFitButton.clicked.connect(runfit)
+ >>> fit_actions_buttons.DismissButton.clicked.connect(dismiss)
+
+ """
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+
+ self.resize(234, 53)
+
+ grid_layout = qt.QGridLayout(self)
+ grid_layout.setContentsMargins(11, 11, 11, 11)
+ grid_layout.setSpacing(6)
+ layout = qt.QHBoxLayout(None)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(6)
+
+ self.EstimateButton = qt.QPushButton(self)
+ self.EstimateButton.setText("Estimate")
+ layout.addWidget(self.EstimateButton)
+ spacer = qt.QSpacerItem(20, 20,
+ qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Minimum)
+ layout.addItem(spacer)
+
+ self.StartFitButton = qt.QPushButton(self)
+ self.StartFitButton.setText("Start Fit")
+ layout.addWidget(self.StartFitButton)
+ spacer_2 = qt.QSpacerItem(20, 20,
+ qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Minimum)
+ layout.addItem(spacer_2)
+
+ self.DismissButton = qt.QPushButton(self)
+ self.DismissButton.setText("Dismiss")
+ layout.addWidget(self.DismissButton)
+
+ grid_layout.addLayout(layout, 0, 0)
+
+
+class FitStatusLines(qt.QWidget):
+ """Widget with 2 greyed out write-only ``QLineEdit``.
+
+ These text widgets can be accessed as public attributes::
+
+ - ``StatusLine``
+ - ``ChisqLine``
+
+ You will typically need to access these widgets to update the displayed
+ text::
+
+ >>> fit_status_lines = FitStatusLines()
+ >>> fit_status_lines.StatusLine.setText("Ready")
+ >>> fit_status_lines.ChisqLine.setText("%6.2f" % 0.01)
+
+ """
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+
+ self.resize(535, 47)
+
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(6)
+
+ self.StatusLabel = qt.QLabel(self)
+ self.StatusLabel.setText("Status:")
+ layout.addWidget(self.StatusLabel)
+
+ self.StatusLine = qt.QLineEdit(self)
+ self.StatusLine.setText("Ready")
+ self.StatusLine.setReadOnly(1)
+ layout.addWidget(self.StatusLine)
+
+ self.ChisqLabel = qt.QLabel(self)
+ self.ChisqLabel.setText("Reduced chisq:")
+ layout.addWidget(self.ChisqLabel)
+
+ self.ChisqLine = qt.QLineEdit(self)
+ self.ChisqLine.setMaximumSize(qt.QSize(16000, 32767))
+ self.ChisqLine.setText("")
+ self.ChisqLine.setReadOnly(1)
+ layout.addWidget(self.ChisqLine)
+
+
+class FitConfigWidget(qt.QWidget):
+ """Widget whose purpose is to select a fit theory and a background
+ theory, load a new fit theory definition file and provide
+ a "Configure" button to open an advanced configuration dialog.
+
+ This is used in :class:`silx.gui.fit.FitWidget.FitWidget`, to offer
+ an interface to quickly modify the main parameters prior to running a fit:
+
+ - select a fitting function through :attr:`FunComBox`
+ - select a background function through :attr:`BkgComBox`
+ - open a dialog for modifying advanced parameters through
+ :attr:`FunConfigureButton`
+ """
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+
+ self.setWindowTitle("FitConfigGUI")
+
+ layout = qt.QGridLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(6)
+
+ self.FunLabel = qt.QLabel(self)
+ self.FunLabel.setText("Function")
+ layout.addWidget(self.FunLabel, 0, 0)
+
+ self.FunComBox = qt.QComboBox(self)
+ self.FunComBox.addItem("Add Function(s)")
+ self.FunComBox.setItemData(self.FunComBox.findText("Add Function(s)"),
+ "Load fit theories from a file",
+ qt.Qt.ToolTipRole)
+ layout.addWidget(self.FunComBox, 0, 1)
+
+ self.BkgLabel = qt.QLabel(self)
+ self.BkgLabel.setText("Background")
+ layout.addWidget(self.BkgLabel, 1, 0)
+
+ self.BkgComBox = qt.QComboBox(self)
+ self.BkgComBox.addItem("Add Background(s)")
+ self.BkgComBox.setItemData(self.BkgComBox.findText("Add Background(s)"),
+ "Load background theories from a file",
+ qt.Qt.ToolTipRole)
+ layout.addWidget(self.BkgComBox, 1, 1)
+
+ self.FunConfigureButton = qt.QPushButton(self)
+ self.FunConfigureButton.setText("Configure")
+ self.FunConfigureButton.setToolTip(
+ "Open a configuration dialog for the selected function")
+ layout.addWidget(self.FunConfigureButton, 0, 2)
+
+ self.BgConfigureButton = qt.QPushButton(self)
+ self.BgConfigureButton.setText("Configure")
+ self.BgConfigureButton.setToolTip(
+ "Open a configuration dialog for the selected background")
+ layout.addWidget(self.BgConfigureButton, 1, 2)
+
+ self.WeightCheckBox = qt.QCheckBox(self)
+ self.WeightCheckBox.setText("Weighted fit")
+ self.WeightCheckBox.setToolTip(
+ "Enable usage of weights in the least-square problem.\n Use" +
+ " the uncertainties (sigma) if provided, else use sqrt(y).")
+
+ layout.addWidget(self.WeightCheckBox, 0, 3, 2, 1)
+
+ layout.setColumnStretch(4, 1)
+
+
+class ParametersTab(qt.QTabWidget):
+ """This widget provides tabs to display and modify fit parameters. Each
+ tab contains a table with fit data such as parameter names, estimated
+ values, fit constraints, and final fit results.
+
+ The usual way to initialize the table is to fill it with the fit
+ parameters from a :class:`silx.math.fit.fitmanager.FitManager` object, after
+ the estimation process or after the final fit.
+
+ In the following example we use a :class:`ParametersTab` to display the
+ results of two separate fits::
+
+ from silx.math.fit import fittheories
+ from silx.math.fit import fitmanager
+ from silx.math.fit import functions
+ from silx.gui import qt
+ import numpy
+
+ a = qt.QApplication([])
+
+ # Create synthetic data
+ x = numpy.arange(1000)
+ y1 = functions.sum_gauss(x, 100, 400, 100)
+
+ fit = fitmanager.FitManager(x=x, y=y1)
+
+ fitfuns = fittheories.FitTheories()
+ fit.addtheory(theory="Gaussian",
+ function=functions.sum_gauss,
+ parameters=("height", "peak center", "fwhm"),
+ estimate=fitfuns.estimate_height_position_fwhm)
+ fit.settheory('Gaussian')
+ fit.configure(PositiveFwhmFlag=True,
+ PositiveHeightAreaFlag=True,
+ AutoFwhm=True,)
+
+ # Fit
+ fit.estimate()
+ fit.runfit()
+
+ # Show first fit result in a tab in our widget
+ w = ParametersTab()
+ w.show()
+ w.fillFromFit(fit.fit_results, view='Gaussians')
+
+ # new synthetic data
+ y2 = functions.sum_splitgauss(x,
+ 100, 400, 100, 40,
+ 10, 600, 50, 500,
+ 80, 850, 10, 50)
+ fit.setData(x=x, y=y2)
+
+ # Define new theory
+ fit.addtheory(theory="Asymetric gaussian",
+ function=functions.sum_splitgauss,
+ parameters=("height", "peak center", "left fwhm", "right fwhm"),
+ estimate=fitfuns.estimate_splitgauss)
+ fit.settheory('Asymetric gaussian')
+
+ # Fit
+ fit.estimate()
+ fit.runfit()
+
+ # Show first fit result in another tab in our widget
+ w.fillFromFit(fit.fit_results, view='Asymetric gaussians')
+ a.exec_()
+
+ """
+
+ def __init__(self, parent=None, name="FitParameters"):
+ """
+
+ :param parent: Parent widget
+ :param name: Widget title
+ """
+ qt.QTabWidget.__init__(self, parent)
+ self.setWindowTitle(name)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ self.views = OrderedDict()
+ """Dictionary of views. Keys are view names,
+ items are :class:`Parameters` widgets"""
+
+ self.latest_view = None
+ """Name of latest view"""
+
+ # the widgets/tables themselves
+ self.tables = {}
+ """Dictionary of :class:`silx.gui.fit.parameters.Parameters` objects.
+ These objects store fit results
+ """
+
+ self.setContentsMargins(10, 10, 10, 10)
+
+ def setView(self, view=None, fitresults=None):
+ """Add or update a table. Fill it with data from a fit
+
+ :param view: Tab name to be added or updated. If ``None``, use the
+ latest view.
+ :param fitresults: Fit data to be added to the table
+ :raise: KeyError if no view name specified and no latest view
+ available.
+ """
+ if view is None:
+ if self.latest_view is not None:
+ view = self.latest_view
+ else:
+ raise KeyError(
+ "No view available. You must specify a view" +
+ " name the first time you call this method."
+ )
+
+ if view in self.tables.keys():
+ table = self.tables[view]
+ else:
+ # create the parameters instance
+ self.tables[view] = Parameters(self)
+ table = self.tables[view]
+ self.views[view] = table
+ self.addTab(table, str(view))
+
+ if fitresults is not None:
+ table.fillFromFit(fitresults)
+
+ self.setCurrentWidget(self.views[view])
+ self.latest_view = view
+
+ def renameView(self, oldname=None, newname=None):
+ """Rename a view (tab)
+
+ :param oldname: Name of the view to be renamed
+ :param newname: New name of the view"""
+ error = 1
+ if newname is not None:
+ if newname not in self.views.keys():
+ if oldname in self.views.keys():
+ parameterlist = self.tables[oldname].getFitResults()
+ self.setView(view=newname, fitresults=parameterlist)
+ self.removeView(oldname)
+ error = 0
+ return error
+
+ def fillFromFit(self, fitparameterslist, view=None):
+ """Update a view with data from a fit (alias for :meth:`setView`)
+
+ :param view: Tab name to be added or updated (default: latest view)
+ :param fitparameterslist: Fit data to be added to the table
+ """
+ self.setView(view=view, fitresults=fitparameterslist)
+
+ def getFitResults(self, name=None):
+ """Call :meth:`getFitResults` for the
+ :class:`silx.gui.fit.parameters.Parameters` corresponding to the
+ latest table or to the named table (if ``name`` is not
+ ``None``). This return a list of dictionaries in the format used by
+ :class:`silx.math.fit.fitmanager.FitManager` to store fit parameter
+ results.
+
+ :param name: View name.
+ """
+ if name is None:
+ name = self.latest_view
+ return self.tables[name].getFitResults()
+
+ def removeView(self, name):
+ """Remove a view by name.
+
+ :param name: View name.
+ """
+ if name in self.views:
+ index = self.indexOf(self.tables[name])
+ self.removeTab(index)
+ index = self.indexOf(self.views[name])
+ self.removeTab(index)
+ del self.tables[name]
+ del self.views[name]
+
+ def removeAllViews(self, keep=None):
+ """Remove all views, except the one specified (argument
+ ``keep``)
+
+ :param keep: Name of the view to be kept."""
+ for view in self.tables:
+ if view != keep:
+ self.removeView(view)
+
+ def getHtmlText(self, name=None):
+ """Return the table data as HTML
+
+ :param name: View name."""
+ if name is None:
+ name = self.latest_view
+ table = self.tables[name]
+ lemon = ("#%x%x%x" % (255, 250, 205)).upper()
+ hcolor = ("#%x%x%x" % (230, 240, 249)).upper()
+ text = ""
+ text += "<nobr>"
+ text += "<table>"
+ text += "<tr>"
+ ncols = table.columnCount()
+ for l in range(ncols):
+ text += ('<td align="left" bgcolor="%s"><b>' % hcolor)
+ if QTVERSION < '4.0.0':
+ text += (str(table.horizontalHeader().label(l)))
+ else:
+ text += (str(table.horizontalHeaderItem(l).text()))
+ text += "</b></td>"
+ text += "</tr>"
+ nrows = table.rowCount()
+ for r in range(nrows):
+ text += "<tr>"
+ item = table.item(r, 0)
+ newtext = ""
+ if item is not None:
+ newtext = str(item.text())
+ if len(newtext):
+ color = "white"
+ b = "<b>"
+ else:
+ b = ""
+ color = lemon
+ try:
+ # MyQTable item has color defined
+ cc = table.item(r, 0).color
+ cc = ("#%x%x%x" % (cc.red(), cc.green(), cc.blue())).upper()
+ color = cc
+ except:
+ pass
+ for c in range(ncols):
+ item = table.item(r, c)
+ newtext = ""
+ if item is not None:
+ newtext = str(item.text())
+ if len(newtext):
+ finalcolor = color
+ else:
+ finalcolor = "white"
+ if c < 2:
+ text += ('<td align="left" bgcolor="%s">%s' %
+ (finalcolor, b))
+ else:
+ text += ('<td align="right" bgcolor="%s">%s' %
+ (finalcolor, b))
+ text += newtext
+ if len(b):
+ text += "</td>"
+ else:
+ text += "</b></td>"
+ item = table.item(r, 0)
+ newtext = ""
+ if item is not None:
+ newtext = str(item.text())
+ if len(newtext):
+ text += "</b>"
+ text += "</tr>"
+ text += "\n"
+ text += "</table>"
+ text += "</nobr>"
+ return text
+
+ def getText(self, name=None):
+ """Return the table data as CSV formatted text, using tabulation
+ characters as separators.
+
+ :param name: View name."""
+ if name is None:
+ name = self.latest_view
+ table = self.tables[name]
+ text = ""
+ ncols = table.columnCount()
+ for l in range(ncols):
+ text += (str(table.horizontalHeaderItem(l).text())) + "\t"
+ text += "\n"
+ nrows = table.rowCount()
+ for r in range(nrows):
+ for c in range(ncols):
+ newtext = ""
+ if c != 4:
+ item = table.item(r, c)
+ if item is not None:
+ newtext = str(item.text())
+ else:
+ item = table.cellWidget(r, c)
+ if item is not None:
+ newtext = str(item.currentText())
+ text += newtext + "\t"
+ text += "\n"
+ text += "\n"
+ return text
+
+
+def test():
+ from silx.math.fit import fittheories
+ from silx.math.fit import fitmanager
+ from silx.math.fit import functions
+ from silx.gui.plot.PlotWindow import PlotWindow
+ import numpy
+
+ a = qt.QApplication([])
+
+ x = numpy.arange(1000)
+ y1 = functions.sum_gauss(x, 100, 400, 100)
+
+ fit = fitmanager.FitManager(x=x, y=y1)
+
+ fitfuns = fittheories.FitTheories()
+ fit.addtheory(name="Gaussian",
+ function=functions.sum_gauss,
+ parameters=("height", "peak center", "fwhm"),
+ estimate=fitfuns.estimate_height_position_fwhm)
+ fit.settheory('Gaussian')
+ fit.configure(PositiveFwhmFlag=True,
+ PositiveHeightAreaFlag=True,
+ AutoFwhm=True,)
+
+ # Fit
+ fit.estimate()
+ fit.runfit()
+
+ w = ParametersTab()
+ w.show()
+ w.fillFromFit(fit.fit_results, view='Gaussians')
+
+ y2 = functions.sum_splitgauss(x,
+ 100, 400, 100, 40,
+ 10, 600, 50, 500,
+ 80, 850, 10, 50)
+ fit.setdata(x=x, y=y2)
+
+ # Define new theory
+ fit.addtheory(name="Asymetric gaussian",
+ function=functions.sum_splitgauss,
+ parameters=("height", "peak center", "left fwhm", "right fwhm"),
+ estimate=fitfuns.estimate_splitgauss)
+ fit.settheory('Asymetric gaussian')
+
+ # Fit
+ fit.estimate()
+ fit.runfit()
+
+ w.fillFromFit(fit.fit_results, view='Asymetric gaussians')
+
+ # Plot
+ pw = PlotWindow(control=True)
+ pw.addCurve(x, y1, "Gaussians")
+ pw.addCurve(x, y2, "Asymetric gaussians")
+ pw.show()
+
+ a.exec_()
+
+
+if __name__ == "__main__":
+ test()
diff --git a/silx/gui/fit/Parameters.py b/silx/gui/fit/Parameters.py
new file mode 100644
index 0000000..62e3278
--- /dev/null
+++ b/silx/gui/fit/Parameters.py
@@ -0,0 +1,882 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ######################################################################### */
+"""This module defines a table widget that is specialized in displaying fit
+parameter results and associated constraints."""
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "25/11/2016"
+
+import sys
+from collections import OrderedDict
+
+from silx.gui import qt
+from silx.gui.widgets.TableWidget import TableWidget
+
+
+def float_else_zero(sstring):
+ """Return converted string to float. If conversion fail, return zero.
+
+ :param sstring: String to be converted
+ :return: ``float(sstrinq)`` if ``sstring`` can be converted to float
+ (e.g. ``"3.14"``), else ``0``
+ """
+ try:
+ return float(sstring)
+ except ValueError:
+ return 0
+
+
+class QComboTableItem(qt.QComboBox):
+ """:class:`qt.QComboBox` augmented with a ``sigCellChanged`` signal
+ to emit a tuple of ``(row, column)`` coordinates when the value is
+ changed.
+
+ This signal can be used to locate the modified combo box in a table.
+
+ :param row: Row number of the table cell containing this widget
+ :param col: Column number of the table cell containing this widget"""
+ sigCellChanged = qt.Signal(int, int)
+ """Signal emitted when this ``QComboBox`` is activated.
+ A ``(row, column)`` tuple is passed."""
+
+ def __init__(self, parent=None, row=None, col=None):
+ self._row = row
+ self._col = col
+ qt.QComboBox.__init__(self, parent)
+ self.activated[int].connect(self._cellChanged)
+
+ def _cellChanged(self, idx): # noqa
+ self.sigCellChanged.emit(self._row, self._col)
+
+
+class QCheckBoxItem(qt.QCheckBox):
+ """:class:`qt.QCheckBox` augmented with a ``sigCellChanged`` signal
+ to emit a tuple of ``(row, column)`` coordinates when the check box has
+ been clicked on.
+
+ This signal can be used to locate the modified check box in a table.
+
+ :param row: Row number of the table cell containing this widget
+ :param col: Column number of the table cell containing this widget"""
+ sigCellChanged = qt.Signal(int, int)
+ """Signal emitted when this ``QCheckBox`` is clicked.
+ A ``(row, column)`` tuple is passed."""
+
+ def __init__(self, parent=None, row=None, col=None):
+ self._row = row
+ self._col = col
+ qt.QCheckBox.__init__(self, parent)
+ self.clicked.connect(self._cellChanged)
+
+ def _cellChanged(self):
+ self.sigCellChanged.emit(self._row, self._col)
+
+
+class Parameters(TableWidget):
+ """:class:`TableWidget` customized to display fit results
+ and to interact with :class:`FitManager` objects.
+
+ Data and references to cell widgets are kept in a dictionary
+ attribute :attr:`parameters`.
+
+ :param parent: Parent widget
+ :param labels: Column headers. If ``None``, default headers will be used.
+ :type labels: List of strings or None
+ :param paramlist: List of fit parameters to be displayed for each fitted
+ peak.
+ :type paramlist: list[str] or None
+ """
+ def __init__(self, parent=None, paramlist=None):
+ TableWidget.__init__(self, parent)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ labels = ['Parameter', 'Estimation', 'Fit Value', 'Sigma',
+ 'Constraints', 'Min/Parame', 'Max/Factor/Delta']
+ tooltips = ["Fit parameter name",
+ "Estimated value for fit parameter. You can edit this column.",
+ "Actual value for parameter, after fit",
+ "Uncertainty (same unit as the parameter)",
+ "Constraint to be applied to the parameter for fit",
+ "First parameter for constraint (name of another param or min value)",
+ "Second parameter for constraint (max value, or factor/delta)"]
+
+ self.columnKeys = ['name', 'estimation', 'fitresult',
+ 'sigma', 'code', 'val1', 'val2']
+ """This list assigns shorter keys to refer to columns than the
+ displayed labels."""
+
+ self.__configuring = False
+
+ # column headers and associated tooltips
+ self.setColumnCount(len(labels))
+
+ for i, label in enumerate(labels):
+ item = self.horizontalHeaderItem(i)
+ if item is None:
+ item = qt.QTableWidgetItem(label,
+ qt.QTableWidgetItem.Type)
+ self.setHorizontalHeaderItem(i, item)
+
+ item.setText(label)
+ if tooltips is not None:
+ item.setToolTip(tooltips[i])
+
+ # resize columns
+ for col_key in ["name", "estimation", "sigma", "val1", "val2"]:
+ col_idx = self.columnIndexByField(col_key)
+ self.resizeColumnToContents(col_idx)
+
+ # Initialize the table with one line per supplied parameter
+ paramlist = paramlist if paramlist is not None else []
+ self.parameters = OrderedDict()
+ """This attribute stores all the data in an ordered dictionary.
+ New data can be added using :meth:`newParameterLine`.
+ Existing data can be modified using :meth:`configureLine`
+
+ Keys of the dictionary are:
+
+ - 'name': parameter name
+ - 'line': line index for the parameter in the table
+ - 'estimation'
+ - 'fitresult'
+ - 'sigma'
+ - 'code': constraint code (one of the elements of
+ :attr:`code_options`)
+ - 'val1': first parameter related to constraint, formatted
+ as a string, as typed in the table
+ - 'val2': second parameter related to constraint, formatted
+ as a string, as typed in the table
+ - 'cons1': scalar representation of 'val1'
+ (e.g. when val1 is the name of a fit parameter, cons1
+ will be the line index of this parameter)
+ - 'cons2': scalar representation of 'val2'
+ - 'vmin': equal to 'val1' when 'code' is "QUOTED"
+ - 'vmax': equal to 'val2' when 'code' is "QUOTED"
+ - 'relatedto': name of related parameter when this parameter
+ is constrained to another parameter (same as 'val1')
+ - 'factor': same as 'val2' when 'code' is 'FACTOR'
+ - 'delta': same as 'val2' when 'code' is 'DELTA'
+ - 'sum': same as 'val2' when 'code' is 'SUM'
+ - 'group': group index for the parameter
+ - 'xmin': data range minimum
+ - 'xmax': data range maximum
+ """
+ for line, param in enumerate(paramlist):
+ self.newParameterLine(param, line)
+
+ self.code_options = ["FREE", "POSITIVE", "QUOTED", "FIXED",
+ "FACTOR", "DELTA", "SUM", "IGNORE", "ADD"]
+ """Possible values in the combo boxes in the 'Constraints' column.
+ """
+
+ # connect signal
+ self.cellChanged[int, int].connect(self.onCellChanged)
+
+ def newParameterLine(self, param, line):
+ """Add a line to the :class:`QTableWidget`.
+
+ Each line represents one of the fit parameters for one of
+ the fitted peaks.
+
+ :param param: Name of the fit parameter
+ :type param: str
+ :param line: 0-based line index
+ :type line: int
+ """
+ # get current number of lines
+ nlines = self.rowCount()
+ self.__configuring = True
+ if line >= nlines:
+ self.setRowCount(line + 1)
+
+ # default configuration for fit parameters
+ self.parameters[param] = OrderedDict((('line', line),
+ ('estimation', '0'),
+ ('fitresult', ''),
+ ('sigma', ''),
+ ('code', 'FREE'),
+ ('val1', ''),
+ ('val2', ''),
+ ('cons1', 0),
+ ('cons2', 0),
+ ('vmin', '0'),
+ ('vmax', '1'),
+ ('relatedto', ''),
+ ('factor', '1.0'),
+ ('delta', '0.0'),
+ ('sum', '0.0'),
+ ('group', ''),
+ ('name', param),
+ ('xmin', None),
+ ('xmax', None)))
+ self.setReadWrite(param, 'estimation')
+ self.setReadOnly(param, ['name', 'fitresult', 'sigma', 'val1', 'val2'])
+
+ # Constraint codes
+ a = []
+ for option in self.code_options:
+ a.append(option)
+
+ code_column_index = self.columnIndexByField('code')
+ cellWidget = self.cellWidget(line, code_column_index)
+ if cellWidget is None:
+ cellWidget = QComboTableItem(self, row=line,
+ col=code_column_index)
+ cellWidget.addItems(a)
+ self.setCellWidget(line, code_column_index, cellWidget)
+ cellWidget.sigCellChanged[int, int].connect(self.onCellChanged)
+ self.parameters[param]['code_item'] = cellWidget
+ self.parameters[param]['relatedto_item'] = None
+ self.__configuring = False
+
+ def columnIndexByField(self, field):
+ """
+
+ :param field: Field name (column key)
+ :return: Index of the column with this field name
+ """
+ return self.columnKeys.index(field)
+
+ def fillFromFit(self, fitresults):
+ """Fill table with values from a list of dictionaries
+ (see :attr:`silx.math.fit.fitmanager.FitManager.fit_results`)
+
+ :param fitresults: List of parameters as recorded
+ in the ``paramlist`` attribute of a :class:`FitManager` object
+ :type fitresults: list[dict]
+ """
+ self.setRowCount(len(fitresults))
+
+ # Reinitialize and fill self.parameters
+ self.parameters = OrderedDict()
+ for (line, param) in enumerate(fitresults):
+ self.newParameterLine(param['name'], line)
+
+ for param in fitresults:
+ name = param['name']
+ code = str(param['code'])
+ if code not in self.code_options:
+ # convert code from int to descriptive string
+ code = self.code_options[int(code)]
+ val1 = param['cons1']
+ val2 = param['cons2']
+ estimation = param['estimation']
+ group = param['group']
+ sigma = param['sigma']
+ fitresult = param['fitresult']
+
+ xmin = param.get('xmin')
+ xmax = param.get('xmax')
+
+ self.configureLine(name=name,
+ code=code,
+ val1=val1, val2=val2,
+ estimation=estimation,
+ fitresult=fitresult,
+ sigma=sigma,
+ group=group,
+ xmin=xmin, xmax=xmax)
+
+ def getConfiguration(self):
+ """Return ``FitManager.paramlist`` dictionary
+ encapsulated in another dictionary"""
+ return {'parameters': self.getFitResults()}
+
+ def setConfiguration(self, ddict):
+ """Fill table with values from a ``FitManager.paramlist`` dictionary
+ encapsulated in another dictionary"""
+ self.fillFromFit(ddict['parameters'])
+
+ def getFitResults(self):
+ """Return fit parameters as a list of dictionaries in the format used
+ by :class:`FitManager` (attribute ``paramlist``).
+ """
+ fitparameterslist = []
+ for param in self.parameters:
+ fitparam = {}
+ name = param
+ estimation, [code, cons1, cons2] = self.getEstimationConstraints(name)
+ buf = str(self.parameters[param]['fitresult'])
+ xmin = self.parameters[param]['xmin']
+ xmax = self.parameters[param]['xmax']
+ if len(buf):
+ fitresult = float(buf)
+ else:
+ fitresult = 0.0
+ buf = str(self.parameters[param]['sigma'])
+ if len(buf):
+ sigma = float(buf)
+ else:
+ sigma = 0.0
+ buf = str(self.parameters[param]['group'])
+ if len(buf):
+ group = float(buf)
+ else:
+ group = 0
+ fitparam['name'] = name
+ fitparam['estimation'] = estimation
+ fitparam['fitresult'] = fitresult
+ fitparam['sigma'] = sigma
+ fitparam['group'] = group
+ fitparam['code'] = code
+ fitparam['cons1'] = cons1
+ fitparam['cons2'] = cons2
+ fitparam['xmin'] = xmin
+ fitparam['xmax'] = xmax
+ fitparameterslist.append(fitparam)
+ return fitparameterslist
+
+ def onCellChanged(self, row, col):
+ """Slot called when ``cellChanged`` signal is emitted.
+ Checks the validity of the new text in the cell, then calls
+ :meth:`configureLine` to update the internal ``self.parameters``
+ dictionary.
+
+ :param row: Row number of the changed cell (0-based index)
+ :param col: Column number of the changed cell (0-based index)
+ """
+ if (col != self.columnIndexByField("code")) and (col != -1):
+ if row != self.currentRow():
+ return
+ if col != self.currentColumn():
+ return
+ if self.__configuring:
+ return
+ param = list(self.parameters)[row]
+ field = self.columnKeys[col]
+ oldvalue = self.parameters[param][field]
+ if col != 4:
+ item = self.item(row, col)
+ if item is not None:
+ newvalue = item.text()
+ else:
+ newvalue = ''
+ else:
+ # this is the combobox
+ widget = self.cellWidget(row, col)
+ newvalue = widget.currentText()
+ if self.validate(param, field, oldvalue, newvalue):
+ paramdict = {"name": param, field: newvalue}
+ self.configureLine(**paramdict)
+ else:
+ if field == 'code':
+ # New code not valid, try restoring the old one
+ index = self.code_options.index(oldvalue)
+ self.__configuring = True
+ try:
+ self.parameters[param]['code_item'].setCurrentIndex(index)
+ finally:
+ self.__configuring = False
+ else:
+ paramdict = {"name": param, field: oldvalue}
+ self.configureLine(**paramdict)
+
+ def validate(self, param, field, oldvalue, newvalue):
+ """Check validity of ``newvalue`` when a cell's value is modified.
+
+ :param param: Fit parameter name
+ :param field: Column name
+ :param oldvalue: Cell value before change attempt
+ :param newvalue: New value to be validated
+ :return: True if new cell value is valid, else False
+ """
+ if field == 'code':
+ return self.setCodeValue(param, oldvalue, newvalue)
+ # FIXME: validate() shouldn't have side effects. Move this bit to configureLine()?
+ if field == 'val1' and str(self.parameters[param]['code']) in ['DELTA', 'FACTOR', 'SUM']:
+ _, candidates = self.getRelatedCandidates(param)
+ # We expect val1 to be a fit parameter name
+ if str(newvalue) in candidates:
+ return True
+ else:
+ return False
+ # except for code, val1 and name (which is read-only and does not need
+ # validation), all fields must always be convertible to float
+ else:
+ try:
+ float(str(newvalue))
+ except ValueError:
+ return False
+ return True
+
+ def setCodeValue(self, param, oldvalue, newvalue):
+ """Update 'code' and 'relatedto' fields when code cell is
+ changed.
+
+ :param param: Fit parameter name
+ :param oldvalue: Cell value before change attempt
+ :param newvalue: New value to be validated
+ :return: ``True`` if code was successfully updated
+ """
+
+ if str(newvalue) in ['FREE', 'POSITIVE', 'QUOTED', 'FIXED']:
+ self.configureLine(name=param,
+ code=newvalue)
+ if str(oldvalue) == 'IGNORE':
+ self.freeRestOfGroup(param)
+ return True
+ elif str(newvalue) in ['FACTOR', 'DELTA', 'SUM']:
+ # I should check here that some parameter is set
+ best, candidates = self.getRelatedCandidates(param)
+ if len(candidates) == 0:
+ return False
+ self.configureLine(name=param,
+ code=newvalue,
+ relatedto=best)
+ if str(oldvalue) == 'IGNORE':
+ self.freeRestOfGroup(param)
+ return True
+
+ elif str(newvalue) == 'IGNORE':
+ # I should check if the group can be ignored
+ # for the time being I just fix all of them to ignore
+ group = int(float(str(self.parameters[param]['group'])))
+ candidates = []
+ for param in self.parameters.keys():
+ if group == int(float(str(self.parameters[param]['group']))):
+ candidates.append(param)
+ # print candidates
+ # I should check here if there is any relation to them
+ for param in candidates:
+ self.configureLine(name=param,
+ code=newvalue)
+ return True
+ elif str(newvalue) == 'ADD':
+ group = int(float(str(self.parameters[param]['group'])))
+ if group == 0:
+ # One cannot add a background group
+ return False
+ i = 0
+ for param in self.parameters:
+ if i <= int(float(str(self.parameters[param]['group']))):
+ i += 1
+ if (group == 0) and (i == 1): # FIXME: why +1?
+ i += 1
+ self.addGroup(i, group)
+ return False
+ elif str(newvalue) == 'SHOW':
+ print(self.getEstimationConstraints(param))
+ return False
+
+ def addGroup(self, newg, gtype):
+ """Add a fit parameter group with the same fit parameters as an
+ existing group.
+
+ This function is called when the user selects "ADD" in the
+ "constraints" combobox.
+
+ :param int newg: New group number
+ :param int gtype: Group number whose parameters we want to copy
+
+ """
+ newparam = []
+ # loop through parameters until we encounter group number `gtype`
+ for param in list(self.parameters):
+ paramgroup = int(float(str(self.parameters[param]['group'])))
+ # copy parameter names in group number `gtype`
+ if paramgroup == gtype:
+ # but replace `gtype` with `newg`
+ newparam.append(param.rstrip("0123456789") + "%d" % newg)
+
+ xmin = self.parameters[param]['xmin']
+ xmax = self.parameters[param]['xmax']
+
+ # Add new parameters (one table line per parameter) and configureLine each
+ # one by updating xmin and xmax to the same values as group `gtype`
+ line = len(list(self.parameters))
+ for param in newparam:
+ self.newParameterLine(param, line)
+ line += 1
+ for param in newparam:
+ self.configureLine(name=param, group=newg, xmin=xmin, xmax=xmax)
+
+ def freeRestOfGroup(self, workparam):
+ """Set ``code`` to ``"FREE"`` for all fit parameters belonging to
+ the same group as ``workparam``. This is done when the entire group
+ of parameters was previously ignored and one of them has his code
+ set to something different than ``"IGNORE"``.
+
+ :param workparam: Fit parameter name
+ """
+ if workparam in self.parameters.keys():
+ group = int(float(str(self.parameters[workparam]['group'])))
+ for param in self.parameters:
+ if param != workparam and\
+ group == int(float(str(self.parameters[param]['group']))):
+ self.configureLine(name=param,
+ code='FREE',
+ cons1=0,
+ cons2=0,
+ val1='',
+ val2='')
+
+ def getRelatedCandidates(self, workparam):
+ """If fit parameter ``workparam`` has a constraint that involves other
+ fit parameters, find possible candidates and try to guess which one
+ is the most likely.
+
+ :param workparam: Fit parameter name
+ :return: (best_candidate, possible_candidates) tuple
+ :rtype: (str, list[str])
+ """
+ candidates = []
+ for param_name in self.parameters:
+ if param_name != workparam:
+ # ignore parameters that are fixed by a constraint
+ if str(self.parameters[param_name]['code']) not in\
+ ['IGNORE', 'FACTOR', 'DELTA', 'SUM']:
+ candidates.append(param_name)
+ # take the previous one (before code cell changed) if possible
+ if str(self.parameters[workparam]['relatedto']) in candidates:
+ best = str(self.parameters[workparam]['relatedto'])
+ return best, candidates
+ # take the first with same base name (after removing numbers)
+ for param_name in candidates:
+ basename = param_name.rstrip("0123456789")
+ try:
+ pos = workparam.index(basename)
+ if pos == 0:
+ best = param_name
+ return best, candidates
+ except ValueError:
+ pass
+ # take the first
+ return candidates[0], candidates
+
+ def setReadOnly(self, parameter, fields):
+ """Make table cells read-only by setting it's flags and omitting
+ flag ``qt.Qt.ItemIsEditable``
+
+ :param parameter: Fit parameter names identifying the rows
+ :type parameter: str or list[str]
+ :param fields: Field names identifying the columns
+ :type fields: str or list[str]
+ """
+ editflags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled
+ self.setField(parameter, fields, editflags)
+
+ def setReadWrite(self, parameter, fields):
+ """Make table cells read-write by setting it's flags including
+ flag ``qt.Qt.ItemIsEditable``
+
+ :param parameter: Fit parameter names identifying the rows
+ :type parameter: str or list[str]
+ :param fields: Field names identifying the columns
+ :type fields: str or list[str]
+ """
+ editflags = qt.Qt.ItemIsSelectable |\
+ qt.Qt.ItemIsEnabled |\
+ qt.Qt.ItemIsEditable
+ self.setField(parameter, fields, editflags)
+
+ def setField(self, parameter, fields, edit_flags):
+ """Set text and flags in a table cell.
+
+ :param parameter: Fit parameter names identifying the rows
+ :type parameter: str or list[str]
+ :param fields: Field names identifying the columns
+ :type fields: str or list[str]
+ :param edit_flags: Flag combination, e.g::
+
+ qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsEditable
+ """
+ if isinstance(parameter, list) or \
+ isinstance(parameter, tuple):
+ paramlist = parameter
+ else:
+ paramlist = [parameter]
+ if isinstance(fields, list) or \
+ isinstance(fields, tuple):
+ fieldlist = fields
+ else:
+ fieldlist = [fields]
+
+ # Set _configuring flag to ignore cellChanged signals in
+ # self.onCellChanged
+ _oldvalue = self.__configuring
+ self.__configuring = True
+
+ # 2D loop through parameter list and field list
+ # to update their cells
+ for param in paramlist:
+ row = list(self.parameters.keys()).index(param)
+ for field in fieldlist:
+ col = self.columnIndexByField(field)
+ if field != 'code':
+ key = field + "_item"
+ item = self.item(row, col)
+ if item is None:
+ item = qt.QTableWidgetItem()
+ item.setText(self.parameters[param][field])
+ self.setItem(row, col, item)
+ else:
+ item.setText(self.parameters[param][field])
+ self.parameters[param][key] = item
+ item.setFlags(edit_flags)
+
+ # Restore previous _configuring flag
+ self.__configuring = _oldvalue
+
+ def configureLine(self, name, code=None, val1=None, val2=None,
+ sigma=None, estimation=None, fitresult=None,
+ group=None, xmin=None, xmax=None, relatedto=None,
+ cons1=None, cons2=None):
+ """This function updates values in a line of the table
+
+ :param name: Name of the parameter (serves as unique identifier for
+ a line).
+ :param code: Constraint code *FREE, FIXED, POSITIVE, DELTA, FACTOR,
+ SUM, QUOTED, IGNORE*
+ :param val1: Constraint 1 (can be the index or name of another
+ parameter for code *DELTA, FACTOR, SUM*, or a min value
+ for code *QUOTED*)
+ :param val2: Constraint 2
+ :param sigma: Standard deviation for a fit parameter
+ :param estimation: Estimated initial value for a fit parameter (used
+ as input to iterative fit)
+ :param fitresult: Final result of fit
+ :param group: Group number of a fit parameter (peak number when doing
+ multi-peak fitting, as each peak corresponds to a group
+ of several consecutive parameters)
+ :param xmin:
+ :param xmax:
+ :param relatedto: Index or name of another fit parameter
+ to which this parameter is related to (constraints)
+ :param cons1: similar meaning to ``val1``, but is always a number
+ :param cons2: similar meaning to ``val2``, but is always a number
+ :return:
+ """
+ paramlist = list(self.parameters.keys())
+
+ if name not in self.parameters:
+ raise KeyError("'%s' is not in the parameter list" % name)
+
+ # update code first, if specified
+ if code is not None:
+ code = str(code)
+ self.parameters[name]['code'] = code
+ # update combobox
+ index = self.parameters[name]['code_item'].findText(code)
+ self.parameters[name]['code_item'].setCurrentIndex(index)
+ else:
+ # set code to previous value, used later for setting val1 val2
+ code = self.parameters[name]['code']
+
+ # val1 and sigma have special formats
+ if val1 is not None:
+ fmt = None if self.parameters[name]['code'] in\
+ ['DELTA', 'FACTOR', 'SUM'] else "%8g"
+ self._updateField(name, "val1", val1, fmat=fmt)
+
+ if sigma is not None:
+ self._updateField(name, "sigma", sigma, fmat="%6.3g")
+
+ # other fields are formatted as "%8g"
+ keys_params = (("val2", val2), ("estimation", estimation),
+ ("fitresult", fitresult))
+ for key, value in keys_params:
+ if value is not None:
+ self._updateField(name, key, value, fmat="%8g")
+
+ # the rest of the parameters are treated as strings and don't need
+ # validation
+ keys_params = (("group", group), ("xmin", xmin),
+ ("xmax", xmax), ("relatedto", relatedto),
+ ("cons1", cons1), ("cons2", cons2))
+ for key, value in keys_params:
+ if value is not None:
+ self.parameters[name][key] = str(value)
+
+ # val1 and val2 have different meanings depending on the code
+ if code == 'QUOTED':
+ if val1 is not None:
+ self.parameters[name]['vmin'] = self.parameters[name]['val1']
+ else:
+ self.parameters[name]['val1'] = self.parameters[name]['vmin']
+ if val2 is not None:
+ self.parameters[name]['vmax'] = self.parameters[name]['val2']
+ else:
+ self.parameters[name]['val2'] = self.parameters[name]['vmax']
+
+ # cons1 and cons2 are scalar representations of val1 and val2
+ self.parameters[name]['cons1'] =\
+ float_else_zero(self.parameters[name]['val1'])
+ self.parameters[name]['cons2'] =\
+ float_else_zero(self.parameters[name]['val2'])
+
+ # cons1, cons2 = min(val1, val2), max(val1, val2)
+ if self.parameters[name]['cons1'] > self.parameters[name]['cons2']:
+ self.parameters[name]['cons1'], self.parameters[name]['cons2'] =\
+ self.parameters[name]['cons2'], self.parameters[name]['cons1']
+
+ elif code in ['DELTA', 'SUM', 'FACTOR']:
+ # For these codes, val1 is the fit parameter name on which the
+ # constraint depends
+ if val1 is not None and val1 in paramlist:
+ self.parameters[name]['relatedto'] = self.parameters[name]["val1"]
+
+ elif val1 is not None:
+ # val1 could be the index of the fit parameter
+ try:
+ self.parameters[name]['relatedto'] = paramlist[int(val1)]
+ except ValueError:
+ self.parameters[name]['relatedto'] = self.parameters[name]["val1"]
+
+ elif relatedto is not None:
+ # code changed, val1 not specified but relatedto specified:
+ # set val1 to relatedto (pre-fill best guess)
+ self.parameters[name]["val1"] = relatedto
+
+ # update fields "delta", "sum" or "factor"
+ key = code.lower()
+ self.parameters[name][key] = self.parameters[name]["val2"]
+
+ # FIXME: val1 is sometimes specified as an index rather than a param name
+ self.parameters[name]['val1'] = self.parameters[name]['relatedto']
+
+ # cons1 is the index of the fit parameter in the ordered dictionary
+ if self.parameters[name]['val1'] in paramlist:
+ self.parameters[name]['cons1'] =\
+ paramlist.index(self.parameters[name]['val1'])
+
+ # cons2 is the constraint value (factor, delta or sum)
+ try:
+ self.parameters[name]['cons2'] =\
+ float(str(self.parameters[name]['val2']))
+ except ValueError:
+ self.parameters[name]['cons2'] = 1.0 if code == "FACTOR" else 0.0
+
+ elif code in ['FREE', 'POSITIVE', 'IGNORE', 'FIXED']:
+ self.parameters[name]['val1'] = ""
+ self.parameters[name]['val2'] = ""
+ self.parameters[name]['cons1'] = 0
+ self.parameters[name]['cons2'] = 0
+
+ self._updateCellRWFlags(name, code)
+
+ def _updateField(self, name, field, value, fmat=None):
+ """Update field in ``self.parameters`` dictionary, if the new value
+ is valid.
+
+ :param name: Fit parameter name
+ :param field: Field name
+ :param value: New value to assign
+ :type value: String
+ :param fmat: Format string (e.g. "%8g") to be applied if value represents
+ a scalar. If ``None``, format is not modified. If ``value`` is an
+ empty string, ``fmat`` is ignored.
+ """
+ if value is not None:
+ oldvalue = self.parameters[name][field]
+ if fmat is not None:
+ newvalue = fmat % float(value) if value != "" else ""
+ else:
+ newvalue = value
+ self.parameters[name][field] = newvalue if\
+ self.validate(name, field, oldvalue, newvalue) else\
+ oldvalue
+
+ def _updateCellRWFlags(self, name, code=None):
+ """Set read-only or read-write flags in a row,
+ depending on the constraint code
+
+ :param name: Fit parameter name identifying the row
+ :param code: Constraint code, in `'FREE', 'POSITIVE', 'IGNORE',`
+ `'FIXED', 'FACTOR', 'DELTA', 'SUM', 'ADD'`
+ :return:
+ """
+ if code in ['FREE', 'POSITIVE', 'IGNORE', 'FIXED']:
+ self.setReadWrite(name, 'estimation')
+ self.setReadOnly(name, ['fitresult', 'sigma', 'val1', 'val2'])
+ else:
+ self.setReadWrite(name, ['estimation', 'val1', 'val2'])
+ self.setReadOnly(name, ['fitresult', 'sigma'])
+
+ def getEstimationConstraints(self, param):
+ """
+ Return tuple ``(estimation, constraints)`` where ``estimation`` is the
+ value in the ``estimate`` field and ``constraints`` are the relevant
+ constraints according to the active code
+ """
+ estimation = None
+ constraints = None
+ if param in self.parameters.keys():
+ buf = str(self.parameters[param]['estimation'])
+ if len(buf):
+ estimation = float(buf)
+ else:
+ estimation = 0
+ if str(self.parameters[param]['code']) in self.code_options:
+ code = self.code_options.index(
+ str(self.parameters[param]['code']))
+ else:
+ code = str(self.parameters[param]['code'])
+ cons1 = self.parameters[param]['cons1']
+ cons2 = self.parameters[param]['cons2']
+ constraints = [code, cons1, cons2]
+ return estimation, constraints
+
+
+def main(args):
+ from silx.math.fit import fittheories
+ from silx.math.fit import fitmanager
+ try:
+ from PyMca5 import PyMcaDataDir
+ except ImportError:
+ raise ImportError("This demo requires PyMca data. Install PyMca5.")
+ import numpy
+ import os
+ app = qt.QApplication(args)
+ tab = Parameters(paramlist=['Height', 'Position', 'FWHM'])
+ tab.showGrid()
+ tab.configureLine(name='Height', estimation='1234', group=0)
+ tab.configureLine(name='Position', code='FIXED', group=1)
+ tab.configureLine(name='FWHM', group=1)
+
+ y = numpy.loadtxt(os.path.join(PyMcaDataDir.PYMCA_DATA_DIR,
+ "XRFSpectrum.mca")) # FIXME
+
+ x = numpy.arange(len(y)) * 0.0502883 - 0.492773
+ fit = fitmanager.FitManager()
+ fit.setdata(x=x, y=y, xmin=20, xmax=150)
+
+ fit.loadtheories(fittheories)
+
+ fit.settheory('ahypermet')
+ fit.configure(Yscaling=1.,
+ PositiveFwhmFlag=True,
+ PositiveHeightAreaFlag=True,
+ FwhmPoints=16,
+ QuotedPositionFlag=1,
+ HypermetTails=1)
+ fit.setbackground('Linear')
+ fit.estimate()
+ fit.runfit()
+ tab.fillFromFit(fit.fit_results)
+ tab.show()
+ app.exec_()
+
+if __name__ == "__main__":
+ main(sys.argv)
diff --git a/silx/gui/fit/__init__.py b/silx/gui/fit/__init__.py
new file mode 100644
index 0000000..e4fd3ab
--- /dev/null
+++ b/silx/gui/fit/__init__.py
@@ -0,0 +1,28 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "07/07/2016"
+
+from .FitWidget import FitWidget
diff --git a/silx/gui/fit/setup.py b/silx/gui/fit/setup.py
new file mode 100644
index 0000000..6672363
--- /dev/null
+++ b/silx/gui/fit/setup.py
@@ -0,0 +1,43 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "21/07/2016"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('fit', parent_package, top_path)
+ config.add_subpackage('test')
+
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/silx/gui/fit/test/__init__.py b/silx/gui/fit/test/__init__.py
new file mode 100644
index 0000000..2236d64
--- /dev/null
+++ b/silx/gui/fit/test/__init__.py
@@ -0,0 +1,43 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import unittest
+
+from .testFitWidget import suite as testFitWidgetSuite
+from .testFitConfig import suite as testFitConfigSuite
+from .testBackgroundWidget import suite as testBackgroundWidgetSuite
+
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "21/07/2016"
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTests(
+ [testFitWidgetSuite(),
+ testFitConfigSuite(),
+ testBackgroundWidgetSuite()])
+ return test_suite
diff --git a/silx/gui/fit/test/testBackgroundWidget.py b/silx/gui/fit/test/testBackgroundWidget.py
new file mode 100644
index 0000000..2e366e4
--- /dev/null
+++ b/silx/gui/fit/test/testBackgroundWidget.py
@@ -0,0 +1,83 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import unittest
+
+from ...test.utils import TestCaseQt
+
+from .. import BackgroundWidget
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+class TestBackgroundWidget(TestCaseQt):
+ def setUp(self):
+ super(TestBackgroundWidget, self).setUp()
+ self.bgdialog = BackgroundWidget.BackgroundDialog()
+ self.bgdialog.setData(list([0, 1, 2, 3]),
+ list([0, 1, 4, 8]))
+ self.qWaitForWindowExposed(self.bgdialog)
+
+ def tearDown(self):
+ del self.bgdialog
+ super(TestBackgroundWidget, self).tearDown()
+
+ def testShow(self):
+ self.bgdialog.show()
+ self.bgdialog.hide()
+
+ def testAccept(self):
+ self.bgdialog.accept()
+ self.assertTrue(self.bgdialog.result())
+
+ def testReject(self):
+ self.bgdialog.reject()
+ self.assertFalse(self.bgdialog.result())
+
+ def testDefaultOutput(self):
+ self.bgdialog.accept()
+ output = self.bgdialog.output
+
+ for key in ["algorithm", "StripThreshold", "SnipWidth",
+ "StripIterations", "StripWidth", "SmoothingFlag",
+ "SmoothingWidth", "AnchorsFlag", "AnchorsList"]:
+ self.assertIn(key, output)
+
+ self.assertFalse(output["AnchorsFlag"])
+ self.assertEqual(output["StripWidth"], 1)
+ self.assertEqual(output["SmoothingFlag"], False)
+ self.assertEqual(output["SmoothingWidth"], 3)
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestBackgroundWidget))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/fit/test/testFitConfig.py b/silx/gui/fit/test/testFitConfig.py
new file mode 100644
index 0000000..eea35cc
--- /dev/null
+++ b/silx/gui/fit/test/testFitConfig.py
@@ -0,0 +1,95 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for :class:`FitConfig`"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+import unittest
+
+from ...test.utils import TestCaseQt
+from .. import FitConfig
+
+
+class TestFitConfig(TestCaseQt):
+ """Basic test for FitWidget"""
+
+ def setUp(self):
+ super(TestFitConfig, self).setUp()
+ self.fit_config = FitConfig.getFitConfigDialog(modal=False)
+ self.qWaitForWindowExposed(self.fit_config)
+
+ def tearDown(self):
+ del self.fit_config
+ super(TestFitConfig, self).tearDown()
+
+ def testShow(self):
+ self.fit_config.show()
+ self.fit_config.hide()
+
+ def testAccept(self):
+ self.fit_config.accept()
+ self.assertTrue(self.fit_config.result())
+
+ def testReject(self):
+ self.fit_config.reject()
+ self.assertFalse(self.fit_config.result())
+
+ def testDefaultOutput(self):
+ self.fit_config.accept()
+ output = self.fit_config.output
+
+ for key in ["AutoFwhm",
+ "PositiveHeightAreaFlag",
+ "QuotedPositionFlag",
+ "PositiveFwhmFlag",
+ "SameFwhmFlag",
+ "QuotedEtaFlag",
+ "NoConstraintsFlag",
+ "FwhmPoints",
+ "Sensitivity",
+ "Yscaling",
+ "ForcePeakPresence",
+ "StripBackgroundFlag",
+ "StripWidth",
+ "StripIterations",
+ "StripThreshold",
+ "SmoothingFlag"]:
+ self.assertIn(key, output)
+
+ self.assertTrue(output["AutoFwhm"])
+ self.assertEqual(output["StripWidth"], 2)
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestFitConfig))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/fit/test/testFitWidget.py b/silx/gui/fit/test/testFitWidget.py
new file mode 100644
index 0000000..d542fd0
--- /dev/null
+++ b/silx/gui/fit/test/testFitWidget.py
@@ -0,0 +1,135 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for :class:`FitWidget`"""
+
+import unittest
+
+from ...test.utils import TestCaseQt
+
+from ... import qt
+from .. import FitWidget
+
+from ....math.fit.fittheory import FitTheory
+from ....math.fit.fitmanager import FitManager
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+class TestFitWidget(TestCaseQt):
+ """Basic test for FitWidget"""
+
+ def setUp(self):
+ super(TestFitWidget, self).setUp()
+ self.fit_widget = FitWidget()
+ self.fit_widget.show()
+ self.qWaitForWindowExposed(self.fit_widget)
+
+ def tearDown(self):
+ self.fit_widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.fit_widget.close()
+ del self.fit_widget
+ super(TestFitWidget, self).tearDown()
+
+ def testShow(self):
+ pass
+
+ def testInteract(self):
+ self.mouseClick(self.fit_widget, qt.Qt.LeftButton)
+ self.keyClick(self.fit_widget, qt.Qt.Key_Enter)
+ self.qapp.processEvents()
+
+ def testCustomConfigWidget(self):
+ class CustomConfigWidget(qt.QDialog):
+ def __init__(self):
+ qt.QDialog.__init__(self)
+ self.setModal(True)
+ self.ok = qt.QPushButton("ok", self)
+ self.ok.clicked.connect(self.accept)
+ cancel = qt.QPushButton("cancel", self)
+ cancel.clicked.connect(self.reject)
+ layout = qt.QVBoxLayout(self)
+ layout.addWidget(self.ok)
+ layout.addWidget(cancel)
+ self.output = {"hello": "world"}
+
+ def fitfun(x, a, b):
+ return a * x + b
+
+ x = list(range(0, 100))
+ y = [fitfun(x_, 2, 3) for x_ in x]
+
+ def conf(**kw):
+ return {"spam": "eggs",
+ "hello": "world!"}
+
+ theory = FitTheory(
+ function=fitfun,
+ parameters=["a", "b"],
+ configure=conf)
+
+ fitmngr = FitManager()
+ fitmngr.setdata(x, y)
+ fitmngr.addtheory("foo", theory)
+ fitmngr.addtheory("bar", theory)
+ fitmngr.addbgtheory("spam", theory)
+
+ fw = FitWidget(fitmngr=fitmngr)
+ fw.associateConfigDialog("spam", CustomConfigWidget(),
+ theory_is_background=True)
+ fw.associateConfigDialog("foo", CustomConfigWidget())
+ fw.show()
+ self.qWaitForWindowExposed(fw)
+
+ fw.bgconfigdialogs["spam"].accept()
+ self.assertTrue(fw.bgconfigdialogs["spam"].result())
+
+ self.assertEqual(fw.bgconfigdialogs["spam"].output,
+ {"hello": "world"})
+
+ fw.bgconfigdialogs["spam"].reject()
+ self.assertFalse(fw.bgconfigdialogs["spam"].result())
+
+ fw.configdialogs["foo"].accept()
+ self.assertTrue(fw.configdialogs["foo"].result())
+
+ # todo: figure out how to click fw.configdialog.ok to close dialog
+ # open dialog
+ # self.mouseClick(fw.guiConfig.FunConfigureButton, qt.Qt.LeftButton)
+ # clove dialog
+ # self.mouseClick(fw.configdialogs["foo"].ok, qt.Qt.LeftButton)
+ # self.qapp.processEvents()
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestFitWidget))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/hdf5/Hdf5HeaderView.py b/silx/gui/hdf5/Hdf5HeaderView.py
new file mode 100644
index 0000000..5912230
--- /dev/null
+++ b/silx/gui/hdf5/Hdf5HeaderView.py
@@ -0,0 +1,192 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "08/11/2016"
+
+
+from .. import qt
+
+QTVERSION = qt.qVersion()
+
+
+class Hdf5HeaderView(qt.QHeaderView):
+ """
+ Default HDF5 header
+
+ Manage auto-resize and context menu to display/hide columns
+ """
+
+ def __init__(self, orientation, parent=None):
+ """
+ Constructor
+
+ :param orientation qt.Qt.Orientation: Orientation of the header
+ :param parent qt.QWidget: Parent of the widget
+ """
+ super(Hdf5HeaderView, self).__init__(orientation, parent)
+ self.setContextMenuPolicy(qt.Qt.CustomContextMenu)
+ self.customContextMenuRequested.connect(self.__createContextMenu)
+
+ # default initialization done by QTreeView for it's own header
+ if QTVERSION < "5.0":
+ self.setClickable(True)
+ self.setMovable(True)
+ else:
+ self.setSectionsClickable(True)
+ self.setSectionsMovable(True)
+ self.setDefaultAlignment(qt.Qt.AlignLeft | qt.Qt.AlignVCenter)
+ self.setStretchLastSection(True)
+
+ self.__auto_resize = True
+ self.__hide_columns_popup = True
+
+ def setModel(self, model):
+ """Override model to configure view when a model is expected
+
+ `qt.QHeaderView.setResizeMode` expect already existing columns
+ to work.
+
+ :param model qt.QAbstractItemModel: A model
+ """
+ super(Hdf5HeaderView, self).setModel(model)
+ self.__updateAutoResize()
+
+ def __updateAutoResize(self):
+ """Update the view according to the state of the auto-resize"""
+ if QTVERSION < "5.0":
+ setResizeMode = self.setResizeMode
+ else:
+ setResizeMode = self.setSectionResizeMode
+
+ if self.__auto_resize:
+ setResizeMode(0, qt.QHeaderView.ResizeToContents)
+ setResizeMode(1, qt.QHeaderView.ResizeToContents)
+ setResizeMode(2, qt.QHeaderView.ResizeToContents)
+ setResizeMode(3, qt.QHeaderView.Interactive)
+ setResizeMode(4, qt.QHeaderView.Interactive)
+ setResizeMode(5, qt.QHeaderView.ResizeToContents)
+ else:
+ setResizeMode(0, qt.QHeaderView.Interactive)
+ setResizeMode(1, qt.QHeaderView.Interactive)
+ setResizeMode(2, qt.QHeaderView.Interactive)
+ setResizeMode(3, qt.QHeaderView.Interactive)
+ setResizeMode(4, qt.QHeaderView.Interactive)
+ setResizeMode(5, qt.QHeaderView.Interactive)
+
+ def setAutoResizeColumns(self, autoResize):
+ """Enable/disable auto-resize. When auto-resized, the header take care
+ of the content of the column to set fixed size of some of them, or to
+ auto fix the size according to the content.
+
+ :param autoResize bool: Enable/disable auto-resize
+ """
+ if self.__auto_resize == autoResize:
+ return
+ self.__auto_resize = autoResize
+ self.__updateAutoResize()
+
+ def hasAutoResizeColumns(self):
+ """Is auto-resize enabled.
+
+ :rtype: bool
+ """
+ return self.__auto_resize
+
+ autoResizeColumns = qt.Property(bool, hasAutoResizeColumns, setAutoResizeColumns)
+ """Property to enable/disable auto-resize."""
+
+ def setEnableHideColumnsPopup(self, enablePopup):
+ """Enable/disable a popup to allow to hide/show each column of the
+ model.
+
+ :param bool enablePopup: Enable/disable popup to hide/show columns
+ """
+ self.__hide_columns_popup = enablePopup
+
+ def hasHideColumnsPopup(self):
+ """Is popup to hide/show columns is enabled.
+
+ :rtype: bool
+ """
+ return self.__hide_columns_popup
+
+ enableHideColumnsPopup = qt.Property(bool, hasHideColumnsPopup, setAutoResizeColumns)
+ """Property to enable/disable popup allowing to hide/show columns."""
+
+ def __genHideSectionEvent(self, column):
+ """Generate a callback which change the column visibility according to
+ the event parameter
+
+ :param int column: logical id of the column
+ :rtype: callable
+ """
+ return lambda checked: self.setSectionHidden(column, not checked)
+
+ def __createContextMenu(self, pos):
+ """Callback to create and display a context menu
+
+ :param pos qt.QPoint: Requested position for the context menu
+ """
+ if not self.__hide_columns_popup:
+ return
+
+ model = self.model()
+ if model.columnCount() > 1:
+ menu = qt.QMenu(self)
+ menu.setTitle("Display/hide columns")
+
+ action = qt.QAction("Display/hide column", self)
+ action.setEnabled(False)
+ menu.addAction(action)
+
+ for column in range(model.columnCount()):
+ if column == 0:
+ # skip the main column
+ continue
+ text = model.headerData(column, qt.Qt.Horizontal, qt.Qt.DisplayRole)
+ action = qt.QAction("%s displayed" % text, self)
+ action.setCheckable(True)
+ action.setChecked(not self.isSectionHidden(column))
+ action.toggled.connect(self.__genHideSectionEvent(column))
+ menu.addAction(action)
+
+ menu.popup(self.viewport().mapToGlobal(pos))
+
+ def setSections(self, logicalIndexes):
+ """
+ Defines order of visible sections by logical indexes.
+
+ Use `Hdf5TreeModel.NAME_COLUMN` to set the list.
+
+ :param list logicalIndexes: List of logical indexes to display
+ """
+ for pos, column_id in enumerate(logicalIndexes):
+ current_pos = self.visualIndex(column_id)
+ self.moveSection(current_pos, pos)
+ self.setSectionHidden(column_id, False)
+ for column_id in set(range(self.model().columnCount())) - set(logicalIndexes):
+ self.setSectionHidden(column_id, True)
diff --git a/silx/gui/hdf5/Hdf5Item.py b/silx/gui/hdf5/Hdf5Item.py
new file mode 100644
index 0000000..40793a4
--- /dev/null
+++ b/silx/gui/hdf5/Hdf5Item.py
@@ -0,0 +1,421 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "20/01/2017"
+
+
+import numpy
+import logging
+import collections
+from .. import qt
+from .. import icons
+from . import _utils
+from .Hdf5Node import Hdf5Node
+import silx.io.utils
+from silx.gui.data.TextFormatter import TextFormatter
+
+_logger = logging.getLogger(__name__)
+
+try:
+ import h5py
+except ImportError as e:
+ _logger.error("Module %s requires h5py", __name__)
+ raise e
+
+_formatter = TextFormatter()
+
+
+class Hdf5Item(Hdf5Node):
+ """Subclass of :class:`qt.QStandardItem` to represent an HDF5-like
+ item (dataset, file, group or link) as an element of a HDF5-like
+ tree structure.
+ """
+
+ def __init__(self, text, obj, parent, key=None, h5pyClass=None, isBroken=False, populateAll=False):
+ """
+ :param str text: text displayed
+ :param object obj: Pointer to h5py data. See the `obj` attribute.
+ """
+ self.__obj = obj
+ self.__key = key
+ self.__h5pyClass = h5pyClass
+ self.__isBroken = isBroken
+ self.__error = None
+ self.__text = text
+ Hdf5Node.__init__(self, parent, populateAll=populateAll)
+
+ @property
+ def obj(self):
+ if self.__key:
+ self.__initH5pyObject()
+ return self.__obj
+
+ @property
+ def basename(self):
+ return self.__text
+
+ @property
+ def h5pyClass(self):
+ """Returns the class of the stored object.
+
+ When the object is in lazy loading, this method should be able to
+ return the type of the futrue loaded object. It allows to delay the
+ real load of the object.
+
+ :rtype: h5py.File or h5py.Dataset or h5py.Group
+ """
+ if self.__h5pyClass is None:
+ self.__h5pyClass = silx.io.utils.get_h5py_class(self.obj)
+ return self.__h5pyClass
+
+ def isGroupObj(self):
+ """Returns true if the stored HDF5 object is a group (contains sub
+ groups or datasets).
+
+ :rtype: bool
+ """
+ return issubclass(self.h5pyClass, h5py.Group)
+
+ def isBrokenObj(self):
+ """Returns true if the stored HDF5 object is broken.
+
+ The stored object is then an h5py link (external or not) which point
+ to nowhere (tbhe external file is not here, the expected dataset is
+ still not on the file...)
+
+ :rtype: bool
+ """
+ return self.__isBroken
+
+ def _expectedChildCount(self):
+ if self.isGroupObj():
+ return len(self.obj)
+ return 0
+
+ def __initH5pyObject(self):
+ """Lazy load of the HDF5 node. It is reached from the parent node
+ with the key of the node."""
+ parent_obj = self.parent.obj
+
+ try:
+ obj = parent_obj.get(self.__key)
+ except Exception as e:
+ _logger.debug("Internal h5py error", exc_info=True)
+ try:
+ self.__obj = parent_obj.get(self.__key, getlink=True)
+ except Exception:
+ self.__obj = None
+ self.__error = e.args[0]
+ self.__isBroken = True
+ else:
+ if obj is None:
+ # that's a broken link
+ self.__obj = parent_obj.get(self.__key, getlink=True)
+
+ # TODO monkey-patch file (ask that in h5py for consistency)
+ if not hasattr(self.__obj, "name"):
+ parent_name = parent_obj.name
+ if parent_name == "/":
+ self.__obj.name = "/" + self.__key
+ else:
+ self.__obj.name = parent_name + "/" + self.__key
+ # TODO monkey-patch file (ask that in h5py for consistency)
+ if not hasattr(self.__obj, "file"):
+ self.__obj.file = parent_obj.file
+
+ if isinstance(self.__obj, h5py.ExternalLink):
+ message = "External link broken. Path %s::%s does not exist" % (self.__obj.filename, self.__obj.path)
+ elif isinstance(self.__obj, h5py.SoftLink):
+ message = "Soft link broken. Path %s does not exist" % (self.__obj.path)
+ else:
+ name = self.obj.__class__.__name__.split(".")[-1].capitalize()
+ message = "%s broken" % (name)
+ self.__error = message
+ self.__isBroken = True
+ else:
+ self.__obj = obj
+
+ self.__key = None
+
+ def _populateChild(self, populateAll=False):
+ if self.isGroupObj():
+ for name in self.obj:
+ try:
+ class_ = self.obj.get(name, getclass=True)
+ has_error = False
+ except Exception as e:
+ _logger.error("Internal h5py error", exc_info=True)
+ try:
+ class_ = self.obj.get(name, getclass=True, getlink=True)
+ except Exception as e:
+ class_ = h5py.HardLink
+ has_error = True
+ item = Hdf5Item(text=name, obj=None, parent=self, key=name, h5pyClass=class_, isBroken=has_error)
+ self.appendChild(item)
+
+ def hasChildren(self):
+ """Retuens true of this node have chrild.
+
+ :rtype: bool
+ """
+ if not self.isGroupObj():
+ return False
+ return Hdf5Node.hasChildren(self)
+
+ def _getDefaultIcon(self):
+ """Returns the icon displayed by the main column.
+
+ :rtype: qt.QIcon
+ """
+ style = qt.QApplication.style()
+ if self.__isBroken:
+ icon = style.standardIcon(qt.QStyle.SP_MessageBoxCritical)
+ return icon
+ class_ = self.h5pyClass
+ if issubclass(class_, h5py.File):
+ return style.standardIcon(qt.QStyle.SP_FileIcon)
+ elif issubclass(class_, h5py.Group):
+ return style.standardIcon(qt.QStyle.SP_DirIcon)
+ elif issubclass(class_, h5py.SoftLink):
+ return style.standardIcon(qt.QStyle.SP_DirLinkIcon)
+ elif issubclass(class_, h5py.ExternalLink):
+ return style.standardIcon(qt.QStyle.SP_FileLinkIcon)
+ elif issubclass(class_, h5py.Dataset):
+ if len(self.obj.shape) < 4:
+ name = "item-%ddim" % len(self.obj.shape)
+ else:
+ name = "item-ndim"
+ if str(self.obj.dtype) == "object":
+ name = "item-object"
+ icon = icons.getQIcon(name)
+ return icon
+ return None
+
+ def _humanReadableShape(self, dataset):
+ if dataset.shape == tuple():
+ return "scalar"
+ shape = [str(i) for i in dataset.shape]
+ text = u" \u00D7 ".join(shape)
+ return text
+
+ def _humanReadableValue(self, dataset):
+ if dataset.shape == tuple():
+ numpy_object = dataset[()]
+ text = _formatter.toString(numpy_object)
+ else:
+ if dataset.size < 5 and dataset.compression is None:
+ numpy_object = dataset[0:5]
+ text = _formatter.toString(numpy_object)
+ else:
+ dimension = len(dataset.shape)
+ if dataset.compression is not None:
+ text = "Compressed %dD data" % dimension
+ else:
+ text = "%dD data" % dimension
+ return text
+
+ def _humanReadableDType(self, dtype, full=False):
+ if dtype.type == numpy.string_:
+ text = "string"
+ elif dtype.type == numpy.unicode_:
+ text = "string"
+ elif dtype.type == numpy.object_:
+ text = "object"
+ elif dtype.type == numpy.bool_:
+ text = "bool"
+ elif dtype.type == numpy.void:
+ if dtype.fields is None:
+ text = "raw"
+ else:
+ if not full:
+ text = "compound"
+ else:
+ compound = [d[0] for d in dtype.fields.values()]
+ compound = [self._humanReadableDType(d) for d in compound]
+ text = "compound(%s)" % ", ".join(compound)
+ else:
+ text = str(dtype)
+ return text
+
+ def _humanReadableType(self, dataset, full=False):
+ return self._humanReadableDType(dataset.dtype, full)
+
+ def _setTooltipAttributes(self, attributeDict):
+ """
+ Add key/value attributes that will be displayed in the item tooltip
+
+ :param Dict[str,str] attributeDict: Key/value attributes
+ """
+ if issubclass(self.h5pyClass, h5py.Dataset):
+ attributeDict["Title"] = "HDF5 Dataset"
+ attributeDict["Name"] = self.basename
+ attributeDict["Path"] = self.obj.name
+ attributeDict["Shape"] = self._humanReadableShape(self.obj)
+ attributeDict["Value"] = self._humanReadableValue(self.obj)
+ attributeDict["Data type"] = self._humanReadableType(self.obj, full=True)
+ elif issubclass(self.h5pyClass, h5py.Group):
+ attributeDict["Title"] = "HDF5 Group"
+ attributeDict["Name"] = self.basename
+ attributeDict["Path"] = self.obj.name
+ elif issubclass(self.h5pyClass, h5py.File):
+ attributeDict["Title"] = "HDF5 File"
+ attributeDict["Name"] = self.basename
+ attributeDict["Path"] = "/"
+ elif isinstance(self.obj, h5py.ExternalLink):
+ attributeDict["Title"] = "HDF5 External Link"
+ attributeDict["Name"] = self.basename
+ attributeDict["Path"] = self.obj.name
+ attributeDict["Linked path"] = self.obj.path
+ attributeDict["Linked file"] = self.obj.filename
+ elif isinstance(self.obj, h5py.SoftLink):
+ attributeDict["Title"] = "HDF5 Soft Link"
+ attributeDict["Name"] = self.basename
+ attributeDict["Path"] = self.obj.name
+ attributeDict["Linked path"] = self.obj.path
+ else:
+ pass
+
+ def _getDefaultTooltip(self):
+ """Returns the default tooltip
+
+ :rtype: str
+ """
+ if self.__error is not None:
+ self.obj # lazy loading of the object
+ return self.__error
+
+ attrs = collections.OrderedDict()
+ self._setTooltipAttributes(attrs)
+
+ title = attrs.pop("Title", None)
+ if len(attrs) > 0:
+ tooltip = _utils.htmlFromDict(attrs, title=title)
+ else:
+ tooltip = ""
+
+ return tooltip
+
+ def dataName(self, role):
+ """Data for the name column"""
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ return self.__text
+ if role == qt.Qt.DecorationRole:
+ return self._getDefaultIcon()
+ if role == qt.Qt.ToolTipRole:
+ return self._getDefaultTooltip()
+ return None
+
+ def dataType(self, role):
+ """Data for the type column"""
+ if role == qt.Qt.DecorationRole:
+ return None
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ if self.__error is not None:
+ return ""
+ class_ = self.h5pyClass
+ if issubclass(class_, h5py.Dataset):
+ text = self._humanReadableType(self.obj)
+ else:
+ text = ""
+ return text
+
+ return None
+
+ def dataShape(self, role):
+ """Data for the shape column"""
+ if role == qt.Qt.DecorationRole:
+ return None
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ if self.__error is not None:
+ return ""
+ class_ = self.h5pyClass
+ if not issubclass(class_, h5py.Dataset):
+ return ""
+ return self._humanReadableShape(self.obj)
+ return None
+
+ def dataValue(self, role):
+ """Data for the value column"""
+ if role == qt.Qt.DecorationRole:
+ return None
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ if self.__error is not None:
+ return ""
+ if not issubclass(self.h5pyClass, h5py.Dataset):
+ return ""
+ return self._humanReadableValue(self.obj)
+ return None
+
+ def dataDescription(self, role):
+ """Data for the description column"""
+ if role == qt.Qt.DecorationRole:
+ return None
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ if self.__isBroken:
+ self.obj # lazy loading of the object
+ return self.__error
+ if "desc" in self.obj.attrs:
+ text = self.obj.attrs["desc"]
+ else:
+ return ""
+ return text
+ if role == qt.Qt.ToolTipRole:
+ if self.__error is not None:
+ self.obj # lazy loading of the object
+ self.__initH5pyObject()
+ return self.__error
+ if "desc" in self.obj.attrs:
+ text = self.obj.attrs["desc"]
+ else:
+ return ""
+ return "Description: %s" % text
+ return None
+
+ def dataNode(self, role):
+ """Data for the node column"""
+ if role == qt.Qt.DecorationRole:
+ return None
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ class_ = self.h5pyClass
+ text = class_.__name__.split(".")[-1]
+ return text
+ if role == qt.Qt.ToolTipRole:
+ class_ = self.h5pyClass
+ return "Class name: %s" % self.__class__
+ return None
diff --git a/silx/gui/hdf5/Hdf5LoadingItem.py b/silx/gui/hdf5/Hdf5LoadingItem.py
new file mode 100644
index 0000000..4467366
--- /dev/null
+++ b/silx/gui/hdf5/Hdf5LoadingItem.py
@@ -0,0 +1,68 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/09/2016"
+
+
+from .. import qt
+from .Hdf5Node import Hdf5Node
+
+
+class Hdf5LoadingItem(Hdf5Node):
+ """Item displayed when an Hdf5Node is loading.
+
+ At the end of the loading this item is replaced by the loaded one.
+ """
+
+ def __init__(self, text, parent, animatedIcon):
+ """Constructor"""
+ Hdf5Node.__init__(self, parent)
+ self.__text = text
+ self.__animatedIcon = animatedIcon
+ self.__animatedIcon.register(self)
+
+ @property
+ def obj(self):
+ return None
+
+ def dataName(self, role):
+ if role == qt.Qt.DecorationRole:
+ return self.__animatedIcon.currentIcon()
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ return self.__text
+ return None
+
+ def dataDescription(self, role):
+ if role == qt.Qt.DecorationRole:
+ return None
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ return "Loading..."
+ return None
diff --git a/silx/gui/hdf5/Hdf5Node.py b/silx/gui/hdf5/Hdf5Node.py
new file mode 100644
index 0000000..31bb097
--- /dev/null
+++ b/silx/gui/hdf5/Hdf5Node.py
@@ -0,0 +1,210 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/09/2016"
+
+
+class Hdf5Node(object):
+ """Abstract tree node
+
+ It provides link to the childs and to the parents, and a link to an
+ external object.
+ """
+ def __init__(self, parent=None, populateAll=False):
+ """
+ Constructor
+
+ :param Hdf5Node parent: Parent of the node, if exists, else None
+ :param bool populateAll: If true, populate all the tree node. Else
+ everything is lazy loaded.
+ """
+ self.__child = None
+ self.__parent = parent
+ if populateAll:
+ self.__child = []
+ self._populateChild(populateAll=True)
+
+ @property
+ def parent(self):
+ """Parent of the node, or None if the node is a root
+
+ :rtype: Hdf5Node
+ """
+ return self.__parent
+
+ def setParent(self, parent):
+ """Redefine the parent of the node.
+
+ It does not set the node as the children of the new parent.
+
+ :param Hdf5Node parent: The new parent
+ """
+ self.__parent = parent
+
+ def appendChild(self, child):
+ """Append a child to the node.
+
+ It does not update the parent of the child.
+
+ :param Hdf5Node child: Child to append to the node.
+ """
+ self.__initChild()
+ self.__child.append(child)
+
+ def removeChildAtIndex(self, index):
+ """Remove a child at an index of the children list.
+
+ The child is removed and returned.
+
+ :param int index: Index in the child list.
+ :rtype: Hdf5Node
+ :raises: IndexError if list is empty or index is out of range.
+ """
+ self.__initChild()
+ return self.__child.pop(index)
+
+ def insertChild(self, index, child):
+ """
+ Insert a child at a specific index of the child list.
+
+ It does not update the parent of the child.
+
+ :param int index: Index in the child list.
+ :param Hdf5Node child: Child to insert in the child list.
+ """
+ self.__initChild()
+ self.__child.insert(index, child)
+
+ def indexOfChild(self, child):
+ """
+ Returns the index of the child in the child list of this node.
+
+ :param Hdf5Node child: Child to find
+ :raises: ValueError if the value is not present.
+ """
+ self.__initChild()
+ return self.__child.index(child)
+
+ def hasChildren(self):
+ """Returns true if the node contains children.
+
+ :rtype: bool
+ """
+ return self.childCount() > 0
+
+ def childCount(self):
+ """Returns the number of child in this node.
+
+ :rtype: int
+ """
+ if self.__child is not None:
+ return len(self.__child)
+ return self._expectedChildCount()
+
+ def child(self, index):
+ """Return the child at an expected index.
+
+ :param int index: Index of the child in the child list of the node
+ :rtype: Hdf5Node
+ """
+ self.__initChild()
+ return self.__child[index]
+
+ def __initChild(self):
+ """Init the child of the node in case the list was lazy loaded."""
+ if self.__child is None:
+ self.__child = []
+ self._populateChild()
+
+ def _expectedChildCount(self):
+ """Returns the expected count of children
+
+ :rtype: int
+ """
+ return 0
+
+ def _populateChild(self, populateAll=False):
+ """Recurse through an HDF5 structure to append groups an datasets
+ into the tree model.
+
+ Overwrite it to implement the initialisation of child of the node.
+ """
+ pass
+
+ def dataName(self, role):
+ """Data for the name column
+
+ Overwrite it to implement the content of the 'name' column.
+
+ :rtype: qt.QVariant
+ """
+ return None
+
+ def dataType(self, role):
+ """Data for the type column
+
+ Overwrite it to implement the content of the 'type' column.
+
+ :rtype: qt.QVariant
+ """
+ return None
+
+ def dataShape(self, role):
+ """Data for the shape column
+
+ Overwrite it to implement the content of the 'shape' column.
+
+ :rtype: qt.QVariant
+ """
+ return None
+
+ def dataValue(self, role):
+ """Data for the value column
+
+ Overwrite it to implement the content of the 'value' column.
+
+ :rtype: qt.QVariant
+ """
+ return None
+
+ def dataDescription(self, role):
+ """Data for the description column
+
+ Overwrite it to implement the content of the 'description' column.
+
+ :rtype: qt.QVariant
+ """
+ return None
+
+ def dataNode(self, role):
+ """Data for the node column
+
+ Overwrite it to implement the content of the 'node' column.
+
+ :rtype: qt.QVariant
+ """
+ return None
diff --git a/silx/gui/hdf5/Hdf5TreeModel.py b/silx/gui/hdf5/Hdf5TreeModel.py
new file mode 100644
index 0000000..fb5de06
--- /dev/null
+++ b/silx/gui/hdf5/Hdf5TreeModel.py
@@ -0,0 +1,581 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "19/12/2016"
+
+
+import os
+import logging
+from .. import qt
+from .. import icons
+from .Hdf5Node import Hdf5Node
+from .Hdf5Item import Hdf5Item
+from .Hdf5LoadingItem import Hdf5LoadingItem
+from . import _utils
+from ... import io as silx_io
+
+_logger = logging.getLogger(__name__)
+
+"""Helpers to take care of None objects as signal parameters.
+PySide crash if a signal with a None parameter is emitted between threads.
+"""
+if qt.BINDING == 'PySide':
+ class _NoneWraper(object):
+ pass
+ _NoneWraperInstance = _NoneWraper()
+
+ def _wrapNone(x):
+ """Wrap x if it is a None value, else returns x"""
+ if x is None:
+ return _NoneWraperInstance
+ else:
+ return x
+
+ def _unwrapNone(x):
+ """Unwrap x as a None if a None was stored by `wrapNone`, else returns
+ x"""
+ if x is _NoneWraperInstance:
+ return None
+ else:
+ return x
+else:
+ # Allow to fix None event params to avoid PySide crashes
+ def _wrapNone(x):
+ return x
+
+ def _unwrapNone(x):
+ return x
+
+
+class LoadingItemRunnable(qt.QRunnable):
+ """Runner to process item loading from a file"""
+
+ class __Signals(qt.QObject):
+ """Signal holder"""
+ itemReady = qt.Signal(object, object, object)
+ runnerFinished = qt.Signal(object)
+
+ def __init__(self, filename, item):
+ """Constructor
+
+ :param LoadingItemWorker worker: Object holding data and signals
+ """
+ super(LoadingItemRunnable, self).__init__()
+ self.filename = filename
+ self.oldItem = item
+ self.signals = self.__Signals()
+
+ def setFile(self, filename, item):
+ self.filenames.append((filename, item))
+
+ @property
+ def itemReady(self):
+ return self.signals.itemReady
+
+ @property
+ def runnerFinished(self):
+ return self.signals.runnerFinished
+
+ def __loadItemTree(self, oldItem, h5obj):
+ """Create an item tree used by the GUI from an h5py object.
+
+ :param Hdf5Node oldItem: The current item displayed the GUI
+ :param h5py.File h5obj: The h5py object to display in the GUI
+ :rtpye: Hdf5Node
+ """
+ if silx_io.is_file(h5obj):
+ text = os.path.basename(h5obj.filename)
+ else:
+ filename = os.path.basename(h5obj.file.filename)
+ path = h5obj.name
+ text = "%s::%s" % (filename, path)
+ item = Hdf5Item(text=text, obj=h5obj, parent=oldItem.parent, populateAll=True)
+ return item
+
+ @qt.Slot()
+ def run(self):
+ """Process the file loading. The worker is used as holder
+ of the data and the signal. The result is sent as a signal.
+ """
+ try:
+ h5file = silx_io.open(self.filename)
+ newItem = self.__loadItemTree(self.oldItem, h5file)
+ error = None
+ except IOError as e:
+ # Should be logged
+ error = e
+ newItem = None
+
+ # Take care of None value in case of PySide
+ newItem = _wrapNone(newItem)
+ error = _wrapNone(error)
+ self.itemReady.emit(self.oldItem, newItem, error)
+ self.runnerFinished.emit(self)
+
+ def autoDelete(self):
+ return True
+
+
+class Hdf5TreeModel(qt.QAbstractItemModel):
+ """Tree model storing a list of :class:`h5py.File` like objects.
+
+ The main column display the :class:`h5py.File` list and there hierarchy.
+ Other columns display information on node hierarchy.
+ """
+
+ H5PY_ITEM_ROLE = qt.Qt.UserRole
+ """Role to reach h5py item from an item index"""
+
+ H5PY_OBJECT_ROLE = qt.Qt.UserRole + 1
+ """Role to reach h5py object from an item index"""
+
+ USER_ROLE = qt.Qt.UserRole + 2
+ """Start of range of available user role for derivative models"""
+
+ NAME_COLUMN = 0
+ """Column id containing HDF5 node names"""
+
+ TYPE_COLUMN = 1
+ """Column id containing HDF5 dataset types"""
+
+ SHAPE_COLUMN = 2
+ """Column id containing HDF5 dataset shapes"""
+
+ VALUE_COLUMN = 3
+ """Column id containing HDF5 dataset values"""
+
+ DESCRIPTION_COLUMN = 4
+ """Column id containing HDF5 node description/title/message"""
+
+ NODE_COLUMN = 5
+ """Column id containing HDF5 node type"""
+
+ COLUMN_IDS = [
+ NAME_COLUMN,
+ TYPE_COLUMN,
+ SHAPE_COLUMN,
+ VALUE_COLUMN,
+ DESCRIPTION_COLUMN,
+ NODE_COLUMN,
+ ]
+ """List of logical columns available"""
+
+ def __init__(self, parent=None):
+ super(Hdf5TreeModel, self).__init__(parent)
+
+ self.treeView = parent
+ self.header_labels = [None] * 6
+ self.header_labels[self.NAME_COLUMN] = 'Name'
+ self.header_labels[self.TYPE_COLUMN] = 'Type'
+ self.header_labels[self.SHAPE_COLUMN] = 'Shape'
+ self.header_labels[self.VALUE_COLUMN] = 'Value'
+ self.header_labels[self.DESCRIPTION_COLUMN] = 'Description'
+ self.header_labels[self.NODE_COLUMN] = 'Node'
+
+ # Create items
+ self.__root = Hdf5Node()
+ self.__fileDropEnabled = True
+ self.__fileMoveEnabled = True
+
+ self.__animatedIcon = icons.getWaitIcon()
+ self.__animatedIcon.iconChanged.connect(self.__updateLoadingItems)
+ self.__runnerSet = set([])
+
+ # store used icons to avoid to avoid the cache to release it
+ self.__icons = []
+ self.__icons.append(icons.getQIcon("item-0dim"))
+ self.__icons.append(icons.getQIcon("item-1dim"))
+ self.__icons.append(icons.getQIcon("item-2dim"))
+ self.__icons.append(icons.getQIcon("item-3dim"))
+ self.__icons.append(icons.getQIcon("item-ndim"))
+ self.__icons.append(icons.getQIcon("item-object"))
+
+ def __updateLoadingItems(self, icon):
+ for i in range(self.__root.childCount()):
+ item = self.__root.child(i)
+ if isinstance(item, Hdf5LoadingItem):
+ index1 = self.index(i, 0, qt.QModelIndex())
+ index2 = self.index(i, self.columnCount() - 1, qt.QModelIndex())
+ self.dataChanged.emit(index1, index2)
+
+ def __itemReady(self, oldItem, newItem, error):
+ """Called at the end of a concurent file loading, when the loading
+ item is ready. AN error is defined if an exception occured when
+ loading the newItem .
+
+ :param Hdf5Node oldItem: current displayed item
+ :param Hdf5Node newItem: item loaded, or None if error is defined
+ :param Exception error: An exception, or None if newItem is defined
+ """
+ # Take care of None value in case of PySide
+ newItem = _unwrapNone(newItem)
+ error = _unwrapNone(error)
+ row = self.__root.indexOfChild(oldItem)
+ rootIndex = qt.QModelIndex()
+ self.beginRemoveRows(rootIndex, row, row)
+ self.__root.removeChildAtIndex(row)
+ self.endRemoveRows()
+ if newItem is not None:
+ self.beginInsertRows(rootIndex, row, row)
+ self.__root.insertChild(row, newItem)
+ self.endInsertRows()
+ # FIXME the error must be displayed
+
+ def isFileDropEnabled(self):
+ return self.__fileDropEnabled
+
+ def setFileDropEnabled(self, enabled):
+ self.__fileDropEnabled = enabled
+
+ fileDropEnabled = qt.Property(bool, isFileDropEnabled, setFileDropEnabled)
+ """Property to enable/disable file dropping in the model."""
+
+ def isFileMoveEnabled(self):
+ return self.__fileMoveEnabled
+
+ def setFileMoveEnabled(self, enabled):
+ self.__fileMoveEnabled = enabled
+
+ fileMoveEnabled = qt.Property(bool, isFileMoveEnabled, setFileMoveEnabled)
+ """Property to enable/disable drag-and-drop of files to
+ change the ordering in the model."""
+
+ def supportedDropActions(self):
+ if self.__fileMoveEnabled or self.__fileDropEnabled:
+ return qt.Qt.CopyAction | qt.Qt.MoveAction
+ else:
+ return 0
+
+ def mimeTypes(self):
+ if self.__fileMoveEnabled:
+ return [_utils.Hdf5NodeMimeData.MIME_TYPE]
+ else:
+ return []
+
+ def mimeData(self, indexes):
+ """
+ Returns an object that contains serialized items of data corresponding
+ to the list of indexes specified.
+
+ :param list(qt.QModelIndex) indexes: List of indexes
+ :rtype: qt.QMimeData
+ """
+ if not self.__fileMoveEnabled or len(indexes) == 0:
+ return None
+
+ indexes = [i for i in indexes if i.column() == 0]
+ if len(indexes) > 1:
+ raise NotImplementedError("Drag of multi rows is not implemented")
+ if len(indexes) == 0:
+ raise NotImplementedError("Drag of cell is not implemented")
+
+ node = self.nodeFromIndex(indexes[0])
+ mimeData = _utils.Hdf5NodeMimeData(node)
+ return mimeData
+
+ def flags(self, index):
+ defaultFlags = qt.QAbstractItemModel.flags(self, index)
+
+ if index.isValid():
+ node = self.nodeFromIndex(index)
+ if self.__fileMoveEnabled and node.parent is self.__root:
+ # that's a root
+ return qt.Qt.ItemIsDragEnabled | defaultFlags
+ return defaultFlags
+ elif self.__fileDropEnabled or self.__fileMoveEnabled:
+ return qt.Qt.ItemIsDropEnabled | defaultFlags
+ else:
+ return defaultFlags
+
+ def dropMimeData(self, mimedata, action, row, column, parentIndex):
+ if action == qt.Qt.IgnoreAction:
+ return True
+
+ if self.__fileMoveEnabled and mimedata.hasFormat(_utils.Hdf5NodeMimeData.MIME_TYPE):
+ dragNode = mimedata.node()
+ parentNode = self.nodeFromIndex(parentIndex)
+ if parentNode is not dragNode.parent:
+ return False
+
+ if row == -1:
+ # append to the parent
+ row = parentNode.childCount()
+ else:
+ # insert at row
+ pass
+
+ dragNodeParent = dragNode.parent
+ sourceRow = dragNodeParent.indexOfChild(dragNode)
+ self.moveRow(parentIndex, sourceRow, parentIndex, row)
+ return True
+
+ if self.__fileDropEnabled and mimedata.hasFormat("text/uri-list"):
+
+ parentNode = self.nodeFromIndex(parentIndex)
+ if parentNode is not self.__root:
+ while(parentNode is not self.__root):
+ node = parentNode
+ parentNode = node.parent
+ row = parentNode.indexOfChild(node)
+ else:
+ if row == -1:
+ row = self.__root.childCount()
+
+ messages = []
+ for url in mimedata.urls():
+ try:
+ self.insertFileAsync(url.toLocalFile(), row)
+ row += 1
+ except IOError as e:
+ messages.append(e.args[0])
+ if len(messages) > 0:
+ title = "Error occurred when loading files"
+ message = "<html>%s:<ul><li>%s</li><ul></html>" % (title, "</li><li>".join(messages))
+ qt.QMessageBox.critical(None, title, message)
+ return True
+
+ return False
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ if orientation == qt.Qt.Horizontal:
+ if role in [qt.Qt.DisplayRole, qt.Qt.EditRole]:
+ return self.header_labels[section]
+ return None
+
+ def insertNode(self, row, node):
+ if row == -1:
+ row = self.__root.childCount()
+ self.beginInsertRows(qt.QModelIndex(), row, row)
+ self.__root.insertChild(row, node)
+ self.endInsertRows()
+
+ def moveRow(self, sourceParentIndex, sourceRow, destinationParentIndex, destinationRow):
+ if sourceRow == destinationRow or sourceRow == destinationRow - 1:
+ # abort move, same place
+ return
+ return self.moveRows(sourceParentIndex, sourceRow, 1, destinationParentIndex, destinationRow)
+
+ def moveRows(self, sourceParentIndex, sourceRow, count, destinationParentIndex, destinationRow):
+ self.beginMoveRows(sourceParentIndex, sourceRow, sourceRow, destinationParentIndex, destinationRow)
+ sourceNode = self.nodeFromIndex(sourceParentIndex)
+ destinationNode = self.nodeFromIndex(destinationParentIndex)
+
+ if sourceNode is destinationNode and sourceRow < destinationRow:
+ item = sourceNode.child(sourceRow)
+ destinationNode.insertChild(destinationRow, item)
+ sourceNode.removeChildAtIndex(sourceRow)
+ else:
+ item = sourceNode.removeChildAtIndex(sourceRow)
+ destinationNode.insertChild(destinationRow, item)
+
+ self.endMoveRows()
+ return True
+
+ def index(self, row, column, parent=qt.QModelIndex()):
+ try:
+ node = self.nodeFromIndex(parent)
+ return self.createIndex(row, column, node.child(row))
+ except IndexError:
+ return qt.QModelIndex()
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ node = self.nodeFromIndex(index)
+
+ if role == self.H5PY_ITEM_ROLE:
+ return node
+
+ if role == self.H5PY_OBJECT_ROLE:
+ return node.obj
+
+ if index.column() == self.NAME_COLUMN:
+ return node.dataName(role)
+ elif index.column() == self.TYPE_COLUMN:
+ return node.dataType(role)
+ elif index.column() == self.SHAPE_COLUMN:
+ return node.dataShape(role)
+ elif index.column() == self.VALUE_COLUMN:
+ return node.dataValue(role)
+ elif index.column() == self.DESCRIPTION_COLUMN:
+ return node.dataDescription(role)
+ elif index.column() == self.NODE_COLUMN:
+ return node.dataNode(role)
+ else:
+ return None
+
+ def columnCount(self, parent=qt.QModelIndex()):
+ return len(self.header_labels)
+
+ def hasChildren(self, parent=qt.QModelIndex()):
+ node = self.nodeFromIndex(parent)
+ if node is None:
+ return 0
+ return node.hasChildren()
+
+ def rowCount(self, parent=qt.QModelIndex()):
+ node = self.nodeFromIndex(parent)
+ if node is None:
+ return 0
+ return node.childCount()
+
+ def parent(self, child):
+ if not child.isValid():
+ return qt.QModelIndex()
+
+ node = self.nodeFromIndex(child)
+
+ if node is None:
+ return qt.QModelIndex()
+
+ parent = node.parent
+
+ if parent is None:
+ return qt.QModelIndex()
+
+ grandparent = parent.parent
+ if grandparent is None:
+ return qt.QModelIndex()
+ row = grandparent.indexOfChild(parent)
+
+ assert row != - 1
+ return self.createIndex(row, 0, parent)
+
+ def nodeFromIndex(self, index):
+ return index.internalPointer() if index.isValid() else self.__root
+
+ def synchronizeIndex(self, index):
+ """
+ Synchronize a file a given its index.
+
+ Basically close it and load it again.
+
+ :param qt.QModelIndex index: Index of the item to update
+ """
+ node = self.nodeFromIndex(index)
+ if node.parent is not self.__root:
+ return
+
+ self.removeIndex(index)
+ filename = node.obj.filename
+ node.obj.close()
+ self.insertFileAsync(filename, index.row())
+
+ def synchronizeH5pyObject(self, h5pyObject):
+ """
+ Synchronize a h5py object in all the tree.
+
+ Basically close it and load it again.
+
+ :param h5py.File h5pyObject: A :class:`h5py.File` object.
+ """
+ index = 0
+ while index < self.__root.childCount():
+ item = self.__root.child(index)
+ if item.obj is h5pyObject:
+ qindex = self.index(index, 0, qt.QModelIndex())
+ self.synchronizeIndex(qindex)
+ else:
+ index += 1
+
+ def removeIndex(self, index):
+ """
+ Remove an item from the model using its index.
+
+ :param qt.QModelIndex index: Index of the item to remove
+ """
+ node = self.nodeFromIndex(index)
+ if node.parent is not self.__root:
+ return
+ self.beginRemoveRows(qt.QModelIndex(), index.row(), index.row())
+ self.__root.removeChildAtIndex(index.row())
+ self.endRemoveRows()
+
+ def removeH5pyObject(self, h5pyObject):
+ """
+ Remove an item from the model using the holding h5py object.
+ It can remove more than one item.
+
+ :param h5py.File h5pyObject: A :class:`h5py.File` object.
+ """
+ index = 0
+ while index < self.__root.childCount():
+ item = self.__root.child(index)
+ if item.obj is h5pyObject:
+ qindex = self.index(index, 0, qt.QModelIndex())
+ self.removeIndex(qindex)
+ else:
+ index += 1
+
+ def insertH5pyObject(self, h5pyObject, text=None, row=-1):
+ """Append an HDF5 object from h5py to the tree.
+
+ :param h5pyObject: File handle/descriptor for a :class:`h5py.File`
+ or any other class of h5py file structure.
+ """
+ if text is None:
+ if silx_io.is_file(h5pyObject):
+ text = os.path.basename(h5pyObject.filename)
+ else:
+ filename = os.path.basename(h5pyObject.file.filename)
+ path = h5pyObject.name
+ text = "%s::%s" % (filename, path)
+ if row == -1:
+ row = self.__root.childCount()
+ self.insertNode(row, Hdf5Item(text=text, obj=h5pyObject, parent=self.__root))
+
+ def insertFileAsync(self, filename, row=-1):
+ if not os.path.isfile(filename):
+ raise IOError("Filename '%s' must be a file path" % filename)
+
+ # create temporary item
+ text = os.path.basename(filename)
+ item = Hdf5LoadingItem(text=text, parent=self.__root, animatedIcon=self.__animatedIcon)
+ self.insertNode(row, item)
+
+ # start loading the real one
+ runnable = LoadingItemRunnable(filename, item)
+ runnable.itemReady.connect(self.__itemReady)
+ self.__runnerSet.add(runnable)
+ runnable.runnerFinished.connect(self.__releaseRunner)
+ qt.QThreadPool.globalInstance().start(runnable)
+
+ def __releaseRunner(self, runner):
+ self.__runnerSet.remove(runner)
+
+ def insertFile(self, filename, row=-1):
+ """Load a HDF5 file into the data model.
+
+ :param filename: file path.
+ """
+ try:
+ h5file = silx_io.open(filename)
+ self.insertH5pyObject(h5file, row=row)
+ except IOError:
+ _logger.debug("File '%s' can't be read.", filename, exc_info=True)
+ raise
+
+ def appendFile(self, filename):
+ self.insertFile(filename, -1)
diff --git a/silx/gui/hdf5/Hdf5TreeView.py b/silx/gui/hdf5/Hdf5TreeView.py
new file mode 100644
index 0000000..09f6fcf
--- /dev/null
+++ b/silx/gui/hdf5/Hdf5TreeView.py
@@ -0,0 +1,204 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "27/09/2016"
+
+
+import logging
+from .. import qt
+from ...utils import weakref as silxweakref
+from .Hdf5TreeModel import Hdf5TreeModel
+from .Hdf5HeaderView import Hdf5HeaderView
+from .NexusSortFilterProxyModel import NexusSortFilterProxyModel
+from .Hdf5Item import Hdf5Item
+from . import _utils
+
+_logger = logging.getLogger(__name__)
+
+
+class Hdf5TreeView(qt.QTreeView):
+ """TreeView which allow to browse HDF5 file structure.
+
+ It provides columns width auto-resizing and additional
+ signals.
+
+ The default model is a :class:`NexusSortFilterProxyModel` sourcing
+ a :class:`Hdf5TreeModel`. The :class:`Hdf5TreeModel` is reachable using
+ :meth:`findHdf5TreeModel`. The default header is :class:`Hdf5HeaderView`.
+
+ Context menu is managed by the :meth:`setContextMenuPolicy` with the value
+ Qt.CustomContextMenu. This policy must not be changed, otherwise context
+ menus will not work anymore. You can use :meth:`addContextMenuCallback` and
+ :meth:`removeContextMenuCallback` to add your custum actions according
+ to the selected objects.
+ """
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param parent qt.QWidget: The parent widget
+ """
+ qt.QTreeView.__init__(self, parent)
+
+ model = Hdf5TreeModel(self)
+ proxy_model = NexusSortFilterProxyModel(self)
+ proxy_model.setSourceModel(model)
+ self.setModel(proxy_model)
+
+ self.setHeader(Hdf5HeaderView(qt.Qt.Horizontal, self))
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.sortByColumn(0, qt.Qt.AscendingOrder)
+ # optimise the rendering
+ self.setUniformRowHeights(True)
+
+ self.setIconSize(qt.QSize(16, 16))
+ self.setAcceptDrops(True)
+ self.setDragEnabled(True)
+ self.setDragDropMode(qt.QAbstractItemView.DragDrop)
+ self.showDropIndicator()
+
+ self.__context_menu_callbacks = silxweakref.WeakList()
+ self.setContextMenuPolicy(qt.Qt.CustomContextMenu)
+ self.customContextMenuRequested.connect(self._createContextMenu)
+
+ def __removeContextMenuProxies(self, ref):
+ """Callback to remove dead proxy from the list"""
+ self.__context_menu_callbacks.remove(ref)
+
+ def _createContextMenu(self, pos):
+ """
+ Create context menu.
+
+ :param pos qt.QPoint: Position of the context menu
+ """
+ actions = []
+
+ menu = qt.QMenu(self)
+
+ hovered_index = self.indexAt(pos)
+ hovered_node = self.model().data(hovered_index, Hdf5TreeModel.H5PY_ITEM_ROLE)
+ if hovered_node is None or not isinstance(hovered_node, Hdf5Item):
+ return
+
+ hovered_object = _utils.H5Node(hovered_node)
+ event = _utils.Hdf5ContextMenuEvent(self, menu, hovered_object)
+
+ for callback in self.__context_menu_callbacks:
+ try:
+ callback(event)
+ except KeyboardInterrupt:
+ raise
+ except:
+ # make sure no user callback crash the application
+ _logger.error("Error while calling callback", exc_info=True)
+ pass
+
+ if len(menu.children()) > 0:
+ for action in actions:
+ menu.addAction(action)
+ menu.popup(self.viewport().mapToGlobal(pos))
+
+ def addContextMenuCallback(self, callback):
+ """Register a context menu callback.
+
+ The callback will be called when a context menu is requested with the
+ treeview and the list of selected h5py objects in parameters. The
+ callback must return a list of :class:`qt.QAction` object.
+
+ Callbacks are stored as saferef. The object must store a reference by
+ itself.
+ """
+ self.__context_menu_callbacks.append(callback)
+
+ def removeContextMenuCallback(self, callback):
+ """Unregister a context menu callback"""
+ self.__context_menu_callbacks.remove(callback)
+
+ def findHdf5TreeModel(self):
+ """Find the Hdf5TreeModel from the stack of model filters.
+
+ :returns: A Hdf5TreeModel, else None
+ :rtype: Hdf5TreeModel
+ """
+ model = self.model()
+ while model is not None:
+ if isinstance(model, qt.QAbstractProxyModel):
+ model = model.sourceModel()
+ else:
+ break
+ if model is None:
+ return None
+ if isinstance(model, Hdf5TreeModel):
+ return model
+ else:
+ return None
+
+ def dragEnterEvent(self, event):
+ model = self.findHdf5TreeModel()
+ if model is not None and model.isFileDropEnabled() and event.mimeData().hasFormat("text/uri-list"):
+ self.setState(qt.QAbstractItemView.DraggingState)
+ event.accept()
+ else:
+ qt.QTreeView.dragEnterEvent(self, event)
+
+ def dragMoveEvent(self, event):
+ model = self.findHdf5TreeModel()
+ if model is not None and model.isFileDropEnabled() and event.mimeData().hasFormat("text/uri-list"):
+ event.setDropAction(qt.Qt.CopyAction)
+ event.accept()
+ else:
+ qt.QTreeView.dragMoveEvent(self, event)
+
+ def selectedH5Nodes(self, ignoreBrokenLinks=True):
+ """Returns selected h5py objects like :class:`h5py.File`,
+ :class:`h5py.Group`, :class:`h5py.Dataset` or mimicked objects.
+
+ :param ignoreBrokenLinks bool: Returns objects which are not not
+ broken links.
+ :rtype: iterator(:class:`_utils.H5Node`)
+ """
+ for index in self.selectedIndexes():
+ if index.column() != 0:
+ continue
+ item = self.model().data(index, Hdf5TreeModel.H5PY_ITEM_ROLE)
+ if item is None:
+ continue
+ if isinstance(item, Hdf5Item):
+ if ignoreBrokenLinks and item.isBrokenObj():
+ continue
+ yield _utils.H5Node(item)
+
+ def mousePressEvent(self, event):
+ """Override mousePressEvent to provide a consistante compatible API
+ between Qt4 and Qt5
+ """
+ super(Hdf5TreeView, self).mousePressEvent(event)
+ if event.button() != qt.Qt.LeftButton:
+ # Qt5 only sends itemClicked on left button mouse click
+ if qt.qVersion() > "5":
+ qindex = self.indexAt(event.pos())
+ self.clicked.emit(qindex)
diff --git a/silx/gui/hdf5/NexusSortFilterProxyModel.py b/silx/gui/hdf5/NexusSortFilterProxyModel.py
new file mode 100644
index 0000000..9a4268c
--- /dev/null
+++ b/silx/gui/hdf5/NexusSortFilterProxyModel.py
@@ -0,0 +1,152 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/04/2017"
+
+
+import logging
+import re
+import numpy
+from .. import qt
+from .Hdf5TreeModel import Hdf5TreeModel
+
+_logger = logging.getLogger(__name__)
+
+try:
+ import h5py
+except ImportError as e:
+ _logger.error("Module %s requires h5py", __name__)
+ raise e
+
+_logger = logging.getLogger(__name__)
+
+
+class NexusSortFilterProxyModel(qt.QSortFilterProxyModel):
+ """Try to sort items according to Nexus structure. Else sort by name."""
+
+ def __init__(self, parent=None):
+ qt.QSortFilterProxyModel.__init__(self, parent)
+ self.__split = re.compile("(\\d+|\\D+)")
+
+ def lessThan(self, sourceLeft, sourceRight):
+ """Returns True if the value of the item referred to by the given
+ index `sourceLeft` is less than the value of the item referred to by
+ the given index `sourceRight`, otherwise returns false.
+
+ :param qt.QModelIndex sourceLeft:
+ :param qt.QModelIndex sourceRight:
+ :rtype: bool
+ """
+ if sourceLeft.column() != Hdf5TreeModel.NAME_COLUMN:
+ return super(NexusSortFilterProxyModel, self).lessThan(
+ sourceLeft, sourceRight)
+
+ # Do not sort child of root (files)
+ if sourceLeft.parent() == qt.QModelIndex():
+ return sourceLeft.row() < sourceRight.row()
+
+ left = self.sourceModel().data(sourceLeft, Hdf5TreeModel.H5PY_ITEM_ROLE)
+ right = self.sourceModel().data(sourceRight, Hdf5TreeModel.H5PY_ITEM_ROLE)
+
+ if self.__isNXentry(left) and self.__isNXentry(right):
+ less = self.childDatasetLessThan(left, right, "start_time")
+ if less is not None:
+ return less
+ less = self.childDatasetLessThan(left, right, "end_time")
+ if less is not None:
+ return less
+
+ left = self.sourceModel().data(sourceLeft, qt.Qt.DisplayRole)
+ right = self.sourceModel().data(sourceRight, qt.Qt.DisplayRole)
+ return self.nameLessThan(left, right)
+
+ def __isNXentry(self, node):
+ """Returns true if the node is an NXentry"""
+ if not issubclass(node.h5pyClass, h5py.Group):
+ return False
+ nxClass = node.obj.attrs.get("NX_class", None)
+ return nxClass == "NXentry"
+
+ def getWordsAndNumbers(self, name):
+ """
+ Returns a list of words and integers composing the name.
+
+ An input `"aaa10bbb50.30"` will return
+ `["aaa", 10, "bbb", 50, ".", 30]`.
+
+ :param str name: A name
+ :rtype: list
+ """
+ words = self.__split.findall(name)
+ result = []
+ for i in words:
+ if i[0].isdigit():
+ i = int(i)
+ result.append(i)
+ return result
+
+ def nameLessThan(self, left, right):
+ """Returns True if the left string is less than the right string.
+
+ Number composing the names are compared as integers, as result "name2"
+ is smaller than "name10".
+
+ :param str left: A string
+ :param str right: A string
+ :rtype: bool
+ """
+ leftList = self.getWordsAndNumbers(left)
+ rightList = self.getWordsAndNumbers(right)
+ try:
+ return leftList < rightList
+ except TypeError:
+ # Back to string comparison if list are not type consistent
+ return left < right
+
+ def childDatasetLessThan(self, left, right, childName):
+ """
+ Reach the same children name of two items and compare their values.
+
+ Returns True if the left one is smaller than the right one.
+
+ :param Hdf5Item left: An item
+ :param Hdf5Item right: An item
+ :param str childName: Name of the children to search. Returns None if
+ the children is not found.
+ :rtype: bool
+ """
+ try:
+ left_time = left.obj[childName][()]
+ right_time = right.obj[childName][()]
+ if isinstance(left_time, numpy.ndarray):
+ return left_time[0] < right_time[0]
+ return left_time < right_time
+ except KeyboardInterrupt:
+ raise
+ except Exception as e:
+ _logger.debug("Exception occurred", exc_info=True)
+ return None
diff --git a/silx/gui/hdf5/__init__.py b/silx/gui/hdf5/__init__.py
new file mode 100644
index 0000000..1b5a602
--- /dev/null
+++ b/silx/gui/hdf5/__init__.py
@@ -0,0 +1,44 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of Qt widgets for displaying content relative to
+HDF5 format.
+
+.. note::
+
+ This package depends on *h5py*.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/09/2016"
+
+
+from .Hdf5TreeView import Hdf5TreeView # noqa
+from ._utils import H5Node
+from ._utils import Hdf5ContextMenuEvent # noqa
+from .NexusSortFilterProxyModel import NexusSortFilterProxyModel # noqa
+from .Hdf5TreeModel import Hdf5TreeModel # noqa
+
+__all__ = ['Hdf5TreeView', 'H5Node', 'Hdf5ContextMenuEvent', 'NexusSortFilterProxyModel', 'Hdf5TreeModel']
diff --git a/silx/gui/hdf5/_utils.py b/silx/gui/hdf5/_utils.py
new file mode 100644
index 0000000..af9c79f
--- /dev/null
+++ b/silx/gui/hdf5/_utils.py
@@ -0,0 +1,247 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of helper class and function used by the
+package `silx.gui.hdf5` package.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "26/04/2017"
+
+
+import logging
+import numpy
+from .. import qt
+import silx.io.utils
+from silx.utils.html import escape
+
+_logger = logging.getLogger(__name__)
+
+try:
+ import h5py
+except ImportError as e:
+ _logger.error("Module %s requires h5py", __name__)
+ raise e
+
+
+class Hdf5ContextMenuEvent(object):
+ """Hold information provided to context menu callbacks."""
+
+ def __init__(self, source, menu, hoveredObject):
+ """
+ Constructor
+
+ :param QWidget source: Widget source
+ :param QMenu menu: Context menu which will be displayed
+ :param H5Node hoveredObject: Hovered H5 node
+ """
+ self.__source = source
+ self.__menu = menu
+ self.__hoveredObject = hoveredObject
+
+ def source(self):
+ """Source of the event
+
+ :rtype: Hdf5TreeView
+ """
+ return self.__source
+
+ def menu(self):
+ """Menu which will be displayed
+
+ :rtype: qt.QMenu
+ """
+ return self.__menu
+
+ def hoveredObject(self):
+ """Item content hovered by the mouse when the context menu was
+ requested
+
+ :rtype: H5Node
+ """
+ return self.__hoveredObject
+
+
+def htmlFromDict(dictionary, title=None):
+ """Generate a readable HTML from a dictionary
+
+ :param dict dictionary: A Dictionary
+ :rtype: str
+ """
+ result = """<html>
+ <head>
+ <style type="text/css">
+ ul { -qt-list-indent: 0; list-style: none; }
+ li > b {display: inline-block; min-width: 4em; font-weight: bold; }
+ </style>
+ </head>
+ <body>
+ """
+ if title is not None:
+ result += "<b>%s</b>" % escape(title)
+ result += "<ul>"
+ for key, value in dictionary.items():
+ result += "<li><b>%s</b>: %s</li>" % (escape(key), escape(value))
+ result += "</ul>"
+ result += "</body></html>"
+ return result
+
+
+class Hdf5NodeMimeData(qt.QMimeData):
+ """Mimedata class to identify an internal drag and drop of a Hdf5Node."""
+
+ MIME_TYPE = "application/x-internal-h5py-node"
+
+ def __init__(self, node=None):
+ qt.QMimeData.__init__(self)
+ self.__node = node
+ self.setData(self.MIME_TYPE, "".encode(encoding='utf-8'))
+
+ def node(self):
+ return self.__node
+
+
+class H5Node(object):
+ """Adapter over an h5py object to provide missing informations from h5py
+ nodes, like internal node path and filename (which are not provided by
+ :mod:`h5py` for soft and external links).
+
+ It also provides an abstraction to reach node type for mimicked h5py
+ objects.
+ """
+
+ def __init__(self, h5py_item=None):
+ """Constructor
+
+ :param Hdf5Item h5py_item: An Hdf5Item
+ """
+ self.__h5py_object = h5py_item.obj
+ self.__h5py_item = h5py_item
+
+ def __getattr__(self, name):
+ return object.__getattribute__(self.__h5py_object, name)
+
+ @property
+ def h5py_object(self):
+ """Returns the internal h5py node.
+
+ :rtype: h5py.File or h5py.Group or h5py.Dataset
+ """
+ return self.__h5py_object
+
+ @property
+ def ntype(self):
+ """Returns the node type, as an h5py class.
+
+ :rtype:
+ :class:`h5py.File`, :class:`h5py.Group` or :class:`h5py.Dataset`
+ """
+ return silx.io.utils.get_h5py_class(self.__h5py_object)
+
+ @property
+ def basename(self):
+ """Returns the basename of this h5py node. It is the last identifier of
+ the path.
+
+ :rtype: str
+ """
+ return self.__h5py_object.name.split("/")[-1]
+
+ @property
+ def local_name(self):
+ """Returns the local path of this h5py node.
+
+ For links, this path is not equal to the h5py one.
+
+ :rtype: str
+ """
+ if self.__h5py_item is None:
+ raise RuntimeError("h5py_item is not defined")
+
+ result = []
+ item = self.__h5py_item
+ while item is not None:
+ if issubclass(item.h5pyClass, h5py.File):
+ break
+ result.append(item.basename)
+ item = item.parent
+ if item is None:
+ raise RuntimeError("The item does not have parent holding h5py.File")
+ if result == []:
+ return "/"
+ result.append("")
+ result.reverse()
+ return "/".join(result)
+
+ def __file_item(self):
+ """Returns the parent item holding the :class:`h5py.File` object
+
+ :rtype: h5py.File
+ :raises RuntimeException: If no file are found
+ """
+ item = self.__h5py_item
+ while item is not None:
+ if issubclass(item.h5pyClass, h5py.File):
+ return item
+ item = item.parent
+ raise RuntimeError("The item does not have parent holding h5py.File")
+
+ @property
+ def local_file(self):
+ """Returns the local :class:`h5py.File` object.
+
+ For path containing external links, this file is not equal to the h5py
+ one.
+
+ :rtype: h5py.File
+ :raises RuntimeException: If no file are found
+ """
+ item = self.__file_item()
+ return item.obj
+
+ @property
+ def local_filename(self):
+ """Returns the local filename of the h5py node.
+
+ For path containing external links, this path is not equal to the
+ filename provided by h5py.
+
+ :rtype: str
+ :raises RuntimeException: If no file are found
+ """
+ return self.local_file.filename
+
+ @property
+ def local_basename(self):
+ """Returns the local filename of the h5py node.
+
+ For path containing links, this basename can be different than the
+ basename provided by h5py.
+
+ :rtype: str
+ """
+ if issubclass(self.__h5py_item.h5pyClass, h5py.File):
+ return ""
+ return self.__h5py_item.basename
diff --git a/silx/gui/hdf5/setup.py b/silx/gui/hdf5/setup.py
new file mode 100644
index 0000000..786a851
--- /dev/null
+++ b/silx/gui/hdf5/setup.py
@@ -0,0 +1,41 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/09/2016"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('hdf5', parent_package, top_path)
+ config.add_subpackage('test')
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+ setup(configuration=configuration)
diff --git a/silx/gui/hdf5/test/__init__.py b/silx/gui/hdf5/test/__init__.py
new file mode 100644
index 0000000..3000d96
--- /dev/null
+++ b/silx/gui/hdf5/test/__init__.py
@@ -0,0 +1,39 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import unittest
+
+from . import test_hdf5
+
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/09/2016"
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTests(
+ [test_hdf5.suite()])
+ return test_suite
diff --git a/silx/gui/hdf5/test/_mock.py b/silx/gui/hdf5/test/_mock.py
new file mode 100644
index 0000000..eada590
--- /dev/null
+++ b/silx/gui/hdf5/test/_mock.py
@@ -0,0 +1,130 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Mock for silx.gui.hdf5 module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/04/2017"
+
+
+import numpy
+try:
+ import h5py
+except ImportError:
+ h5py = None
+
+
+class Node(object):
+
+ def __init__(self, basename, parent, h5py_class):
+ self.basename = basename
+ self.h5py_class = h5py_class
+ self.attrs = {}
+ self.parent = parent
+ if parent is not None:
+ self.parent._add(self)
+
+ @property
+ def name(self):
+ if self.parent is None:
+ return self.basename
+ if self.parent.name == "":
+ return self.basename
+ return self.parent.name + "/" + self.basename
+
+ @property
+ def file(self):
+ if self.parent is None:
+ return self
+ return self.parent.file
+
+
+class Group(Node):
+ """Mock an h5py Group"""
+
+ def __init__(self, name, parent, h5py_class=h5py.Group):
+ super(Group, self).__init__(name, parent, h5py_class)
+ self.__items = {}
+
+ def _add(self, node):
+ self.__items[node.basename] = node
+
+ def __getitem__(self, key):
+ return self.__items[key]
+
+ def __iter__(self):
+ for k in self.__items:
+ yield k
+
+ def __len__(self):
+ return len(self.__items)
+
+ def get(self, name, getclass=False, getlink=False):
+ result = self.__items[name]
+ if getclass:
+ return result.h5py_class
+ return result
+
+ def create_dataset(self, name, data):
+ return Dataset(name, self, data)
+
+ def create_group(self, name):
+ return Group(name, self)
+
+ def create_NXentry(self, name):
+ group = Group(name, self)
+ group.attrs["NX_class"] = "NXentry"
+ return group
+
+
+class File(Group):
+ """Mock an h5py File"""
+
+ def __init__(self, filename):
+ super(File, self).__init__("", None, h5py.File)
+ self.filename = filename
+
+
+class Dataset(Node):
+ """Mock an h5py Dataset"""
+
+ def __init__(self, name, parent, value):
+ super(Dataset, self).__init__(name, parent, h5py.Dataset)
+ self.__value = value
+ self.shape = self.__value.shape
+ self.dtype = self.__value.dtype
+ self.size = self.__value.size
+ self.compression = None
+ self.compression_opts = None
+
+ def __getitem__(self, key):
+ if not isinstance(self.__value, numpy.ndarray):
+ if key == tuple():
+ return self.__value
+ elif key == Ellipsis:
+ return numpy.array(self.__value)
+ else:
+ raise ValueError("Bad key")
+ return self.__value[key]
diff --git a/silx/gui/hdf5/test/test_hdf5.py b/silx/gui/hdf5/test/test_hdf5.py
new file mode 100644
index 0000000..3bf4897
--- /dev/null
+++ b/silx/gui/hdf5/test/test_hdf5.py
@@ -0,0 +1,480 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test for silx.gui.hdf5 module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/04/2017"
+
+
+import time
+import os
+import unittest
+import tempfile
+import numpy
+from contextlib import contextmanager
+from silx.gui import qt
+from silx.gui.test.utils import TestCaseQt
+from silx.gui import hdf5
+from . import _mock
+
+try:
+ import h5py
+except ImportError:
+ h5py = None
+
+
+_called = 0
+
+
+class _Holder(object):
+ def callback(self, *args, **kvargs):
+ _called += 1
+
+
+class TestHdf5TreeModel(TestCaseQt):
+
+ def setUp(self):
+ super(TestHdf5TreeModel, self).setUp()
+ if h5py is None:
+ self.skipTest("h5py is not available")
+
+ @contextmanager
+ def h5TempFile(self):
+ # create tmp file
+ fd, tmp_name = tempfile.mkstemp(suffix=".h5")
+ os.close(fd)
+ # create h5 data
+ h5file = h5py.File(tmp_name, "w")
+ g = h5file.create_group("arrays")
+ g.create_dataset("scalar", data=10)
+ h5file.close()
+ yield tmp_name
+ # clean up
+ os.unlink(tmp_name)
+
+ def testCreate(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertIsNotNone(model)
+
+ def testAppendFilename(self):
+ with self.h5TempFile() as filename:
+ model = hdf5.Hdf5TreeModel()
+ self.assertEquals(model.rowCount(qt.QModelIndex()), 0)
+ model.appendFile(filename)
+ self.assertEquals(model.rowCount(qt.QModelIndex()), 1)
+ # clean up
+ index = model.index(0, 0, qt.QModelIndex())
+ h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ h5File.close()
+
+ def testAppendBadFilename(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertRaises(IOError, model.appendFile, "#%$")
+
+ def testInsertFilename(self):
+ with self.h5TempFile() as filename:
+ model = hdf5.Hdf5TreeModel()
+ self.assertEquals(model.rowCount(qt.QModelIndex()), 0)
+ model.insertFile(filename)
+ self.assertEquals(model.rowCount(qt.QModelIndex()), 1)
+ # clean up
+ index = model.index(0, 0, qt.QModelIndex())
+ h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ h5File.close()
+
+ def testInsertFilenameAsync(self):
+ with self.h5TempFile() as filename:
+ model = hdf5.Hdf5TreeModel()
+ self.assertEquals(model.rowCount(qt.QModelIndex()), 0)
+ model.insertFileAsync(filename)
+ index = model.index(0, 0, qt.QModelIndex())
+ self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5LoadingItem.Hdf5LoadingItem)
+ time.sleep(0.1)
+ self.qapp.processEvents()
+ time.sleep(0.1)
+ index = model.index(0, 0, qt.QModelIndex())
+ self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item)
+ self.assertEquals(model.rowCount(qt.QModelIndex()), 1)
+ # clean up
+ index = model.index(0, 0, qt.QModelIndex())
+ h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ h5File.close()
+
+ def testInsertObject(self):
+ h5 = _mock.File("/foo/bar/1.mock")
+ model = hdf5.Hdf5TreeModel()
+ self.assertEquals(model.rowCount(qt.QModelIndex()), 0)
+ model.insertH5pyObject(h5)
+ self.assertEquals(model.rowCount(qt.QModelIndex()), 1)
+
+ def testRemoveObject(self):
+ h5 = _mock.File("/foo/bar/1.mock")
+ model = hdf5.Hdf5TreeModel()
+ self.assertEquals(model.rowCount(qt.QModelIndex()), 0)
+ model.insertH5pyObject(h5)
+ self.assertEquals(model.rowCount(qt.QModelIndex()), 1)
+ model.removeH5pyObject(h5)
+ self.assertEquals(model.rowCount(qt.QModelIndex()), 0)
+
+ def testSynchronizeObject(self):
+ with self.h5TempFile() as filename:
+ h5 = h5py.File(filename)
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(h5)
+ self.assertEquals(model.rowCount(qt.QModelIndex()), 1)
+ index = model.index(0, 0, qt.QModelIndex())
+ node1 = model.nodeFromIndex(index)
+ model.synchronizeH5pyObject(h5)
+ index = model.index(0, 0, qt.QModelIndex())
+ node2 = model.nodeFromIndex(index)
+ self.assertIsNot(node1, node2)
+ # after sync
+ time.sleep(0.1)
+ self.qapp.processEvents()
+ time.sleep(0.1)
+ index = model.index(0, 0, qt.QModelIndex())
+ self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item)
+ # clean up
+ index = model.index(0, 0, qt.QModelIndex())
+ h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ h5File.close()
+
+ def testFileMoveState(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertEquals(model.isFileMoveEnabled(), True)
+ model.setFileMoveEnabled(False)
+ self.assertEquals(model.isFileMoveEnabled(), False)
+
+ def testFileDropState(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertEquals(model.isFileDropEnabled(), True)
+ model.setFileDropEnabled(False)
+ self.assertEquals(model.isFileDropEnabled(), False)
+
+ def testSupportedDrop(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertNotEquals(model.supportedDropActions(), 0)
+
+ model.setFileMoveEnabled(False)
+ model.setFileDropEnabled(False)
+ self.assertEquals(model.supportedDropActions(), 0)
+
+ model.setFileMoveEnabled(False)
+ model.setFileDropEnabled(True)
+ self.assertNotEquals(model.supportedDropActions(), 0)
+
+ model.setFileMoveEnabled(True)
+ model.setFileDropEnabled(False)
+ self.assertNotEquals(model.supportedDropActions(), 0)
+
+ def testDropExternalFile(self):
+ with self.h5TempFile() as filename:
+ model = hdf5.Hdf5TreeModel()
+ mimeData = qt.QMimeData()
+ mimeData.setUrls([qt.QUrl.fromLocalFile(filename)])
+ model.dropMimeData(mimeData, qt.Qt.CopyAction, 0, 0, qt.QModelIndex())
+ self.assertEquals(model.rowCount(qt.QModelIndex()), 1)
+ # after sync
+ time.sleep(0.1)
+ self.qapp.processEvents()
+ time.sleep(0.1)
+ index = model.index(0, 0, qt.QModelIndex())
+ self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item)
+ # clean up
+ index = model.index(0, 0, qt.QModelIndex())
+ h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ h5File.close()
+
+ def getRowDataAsDict(self, model, row):
+ displayed = {}
+ roles = [qt.Qt.DisplayRole, qt.Qt.DecorationRole, qt.Qt.ToolTipRole, qt.Qt.TextAlignmentRole]
+ for column in range(0, model.columnCount(qt.QModelIndex())):
+ index = model.index(0, column, qt.QModelIndex())
+ for role in roles:
+ datum = model.data(index, role)
+ displayed[column, role] = datum
+ return displayed
+
+ def getItemName(self, model, row):
+ index = model.index(row, hdf5.Hdf5TreeModel.NAME_COLUMN, qt.QModelIndex())
+ return model.data(index, qt.Qt.DisplayRole)
+
+ def testFileData(self):
+ h5 = _mock.File("/foo/bar/1.mock")
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(h5)
+ displayed = self.getRowDataAsDict(model, row=0)
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock")
+ self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "File")
+
+ def testGroupData(self):
+ h5 = _mock.File("/foo/bar/1.mock")
+ d = h5.create_group("foo")
+ d.attrs["desc"] = "fooo"
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(d)
+ displayed = self.getRowDataAsDict(model, row=0)
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo")
+ self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "fooo")
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Group")
+
+ def testDatasetData(self):
+ h5 = _mock.File("/foo/bar/1.mock")
+ value = numpy.array([1, 2, 3])
+ d = h5.create_dataset("foo", value)
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(d)
+ displayed = self.getRowDataAsDict(model, row=0)
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo")
+ self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], value.dtype.name)
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "3")
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "[1 2 3]")
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEquals(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Dataset")
+
+ def testDropLastAsFirst(self):
+ model = hdf5.Hdf5TreeModel()
+ h5_1 = _mock.File("/foo/bar/1.mock")
+ h5_2 = _mock.File("/foo/bar/2.mock")
+ model.insertH5pyObject(h5_1)
+ model.insertH5pyObject(h5_2)
+ self.assertEquals(self.getItemName(model, 0), "1.mock")
+ self.assertEquals(self.getItemName(model, 1), "2.mock")
+ index = model.index(1, 0, qt.QModelIndex())
+ mimeData = model.mimeData([index])
+ model.dropMimeData(mimeData, qt.Qt.MoveAction, 0, 0, qt.QModelIndex())
+ self.assertEquals(self.getItemName(model, 0), "2.mock")
+ self.assertEquals(self.getItemName(model, 1), "1.mock")
+
+ def testDropFirstAsLast(self):
+ model = hdf5.Hdf5TreeModel()
+ h5_1 = _mock.File("/foo/bar/1.mock")
+ h5_2 = _mock.File("/foo/bar/2.mock")
+ model.insertH5pyObject(h5_1)
+ model.insertH5pyObject(h5_2)
+ self.assertEquals(self.getItemName(model, 0), "1.mock")
+ self.assertEquals(self.getItemName(model, 1), "2.mock")
+ index = model.index(0, 0, qt.QModelIndex())
+ mimeData = model.mimeData([index])
+ model.dropMimeData(mimeData, qt.Qt.MoveAction, 2, 0, qt.QModelIndex())
+ self.assertEquals(self.getItemName(model, 0), "2.mock")
+ self.assertEquals(self.getItemName(model, 1), "1.mock")
+
+ def testRootParent(self):
+ model = hdf5.Hdf5TreeModel()
+ h5_1 = _mock.File("/foo/bar/1.mock")
+ model.insertH5pyObject(h5_1)
+ index = model.index(0, 0, qt.QModelIndex())
+ index = model.parent(index)
+ self.assertEquals(index, qt.QModelIndex())
+
+
+class TestNexusSortFilterProxyModel(TestCaseQt):
+
+ def getChildNames(self, model, index):
+ count = model.rowCount(index)
+ result = []
+ for row in range(0, count):
+ itemIndex = model.index(row, hdf5.Hdf5TreeModel.NAME_COLUMN, index)
+ name = model.data(itemIndex, qt.Qt.DisplayRole)
+ result.append(name)
+ return result
+
+ def testNXentryStartTime(self):
+ """Test NXentry with start_time"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = _mock.File("/foo/bar/1.mock")
+ h5.create_NXentry("a").create_dataset("start_time", numpy.string_("2015"))
+ h5.create_NXentry("b").create_dataset("start_time", numpy.string_("2013"))
+ h5.create_NXentry("c").create_dataset("start_time", numpy.string_("2014"))
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.DescendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "c", "b"])
+
+ def testNXentryStartTimeInArray(self):
+ """Test NXentry with start_time"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = _mock.File("/foo/bar/1.mock")
+ h5.create_NXentry("a").create_dataset("start_time", numpy.array([numpy.string_("2015")]))
+ h5.create_NXentry("b").create_dataset("start_time", numpy.array([numpy.string_("2013")]))
+ h5.create_NXentry("c").create_dataset("start_time", numpy.array([numpy.string_("2014")]))
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.DescendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "c", "b"])
+
+ def testNXentryEndTimeInArray(self):
+ """Test NXentry with end_time"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = _mock.File("/foo/bar/1.mock")
+ h5.create_NXentry("a").create_dataset("end_time", numpy.array([numpy.string_("2015")]))
+ h5.create_NXentry("b").create_dataset("end_time", numpy.array([numpy.string_("2013")]))
+ h5.create_NXentry("c").create_dataset("end_time", numpy.array([numpy.string_("2014")]))
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.DescendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "c", "b"])
+
+ def testNXentryName(self):
+ """Test NXentry without start_time or end_time"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = _mock.File("/foo/bar/1.mock")
+ h5.create_NXentry("a")
+ h5.create_NXentry("c")
+ h5.create_NXentry("b")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "b", "c"])
+
+ def testStartTime(self):
+ """If it is not NXentry, start_time is not used"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = _mock.File("/foo/bar/1.mock")
+ h5.create_group("a").create_dataset("start_time", numpy.string_("2015"))
+ h5.create_group("b").create_dataset("start_time", numpy.string_("2013"))
+ h5.create_group("c").create_dataset("start_time", numpy.string_("2014"))
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "b", "c"])
+
+ def testName(self):
+ model = hdf5.Hdf5TreeModel()
+ h5 = _mock.File("/foo/bar/1.mock")
+ h5.create_group("a")
+ h5.create_group("c")
+ h5.create_group("b")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "b", "c"])
+
+ def testNumber(self):
+ model = hdf5.Hdf5TreeModel()
+ h5 = _mock.File("/foo/bar/1.mock")
+ h5.create_group("a1")
+ h5.create_group("a20")
+ h5.create_group("a3")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a1", "a3", "a20"])
+
+ def testMultiNumber(self):
+ model = hdf5.Hdf5TreeModel()
+ h5 = _mock.File("/foo/bar/1.mock")
+ h5.create_group("a1-1")
+ h5.create_group("a20-1")
+ h5.create_group("a3-1")
+ h5.create_group("a3-20")
+ h5.create_group("a3-3")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a1-1", "a3-1", "a3-3", "a3-20", "a20-1"])
+
+ def testUnconsistantTypes(self):
+ model = hdf5.Hdf5TreeModel()
+ h5 = _mock.File("/foo/bar/1.mock")
+ h5.create_group("aaa100")
+ h5.create_group("100aaa")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["100aaa", "aaa100"])
+
+
+class TestHdf5(TestCaseQt):
+ """Test to check that icons module."""
+
+ def setUp(self):
+ super(TestHdf5, self).setUp()
+ if h5py is None:
+ self.skipTest("h5py is not available")
+
+ def testCreate(self):
+ view = hdf5.Hdf5TreeView()
+ self.assertIsNotNone(view)
+
+ def testContextMenu(self):
+ view = hdf5.Hdf5TreeView()
+ view._createContextMenu(qt.QPoint(0, 0))
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestHdf5TreeModel))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestNexusSortFilterProxyModel))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestHdf5))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/icons.py b/silx/gui/icons.py
new file mode 100644
index 0000000..eaf83b8
--- /dev/null
+++ b/silx/gui/icons.py
@@ -0,0 +1,360 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Set of icons for buttons.
+
+Use :func:`getQIcon` to create Qt QIcon from the name identifying an icon.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/04/2017"
+
+
+import logging
+import weakref
+from . import qt
+from silx.resources import resource_filename
+from silx.utils import weakref as silxweakref
+from silx.utils.decorators import deprecated
+
+
+_logger = logging.getLogger(__name__)
+"""Module logger"""
+
+
+_cached_icons = weakref.WeakValueDictionary()
+"""Cache loaded icons in a weak structure"""
+
+
+_supported_formats = None
+"""Order of file format extension to check"""
+
+
+class AbstractAnimatedIcon(qt.QObject):
+ """Store an animated icon.
+
+ It provides an event containing the new icon everytime it is updated."""
+
+ def __init__(self, parent=None):
+ """Constructor
+
+ :param qt.QObject parent: Parent of the QObject
+ :raises: ValueError when name is not known
+ """
+ qt.QObject.__init__(self, parent)
+
+ self.__targets = silxweakref.WeakList()
+ self.__currentIcon = None
+
+ iconChanged = qt.Signal(qt.QIcon)
+ """Signal sent with a QIcon everytime the animation changed."""
+
+ def register(self, obj):
+ """Register an object to the AnimatedIcon.
+ If no object are registred, the animation is paused.
+ Object are stored in a weaked list.
+
+ :param object obj: An object
+ """
+ if obj not in self.__targets:
+ self.__targets.append(obj)
+ self._updateState()
+
+ def unregister(self, obj):
+ """Remove the object from the registration.
+ If no object are registred the animation is paused.
+
+ :param object obj: A registered object
+ """
+ if obj in self.__targets:
+ self.__targets.remove(obj)
+ self._updateState()
+
+ def hasRegistredObjects(self):
+ """Returns true if any object is registred.
+
+ :rtype: bool
+ """
+ return len(self.__targets)
+
+ def isRegistered(self, obj):
+ """Returns true if the object is registred in the AnimatedIcon.
+
+ :param object obj: An object
+ :rtype: bool
+ """
+ return obj in self.__targets
+
+ def currentIcon(self):
+ """Returns the icon of the current frame.
+
+ :rtype: qt.QIcon
+ """
+ return self.__currentIcon
+
+ def _updateState(self):
+ """Update the object according to the connected objects."""
+ pass
+
+ def _setCurrentIcon(self, icon):
+ """Store the current icon and emit a `iconChanged` event.
+
+ :param qt.QIcon icon: The current icon
+ """
+ self.__currentIcon = icon
+ self.iconChanged.emit(self.__currentIcon)
+
+
+class MovieAnimatedIcon(AbstractAnimatedIcon):
+ """Store a looping QMovie to provide icons for each frames.
+ Provides an event with the new icon everytime the movie frame
+ is updated."""
+
+ def __init__(self, filename, parent=None):
+ """Constructor
+
+ :param str filename: An icon name to an animated format
+ :param qt.QObject parent: Parent of the QObject
+ :raises: ValueError when name is not known
+ """
+ AbstractAnimatedIcon.__init__(self, parent)
+
+ qfile = getQFile(filename)
+ self.__movie = qt.QMovie(qfile.fileName(), qt.QByteArray(), parent)
+ self.__movie.setCacheMode(qt.QMovie.CacheAll)
+ self.__movie.frameChanged.connect(self.__frameChanged)
+ self.__cacheIcons = {}
+
+ self.__movie.jumpToFrame(0)
+ self.__updateIconAtFrame(0)
+
+ def __frameChanged(self, frameId):
+ """Callback everytime the QMovie frame change
+ :param int frameId: Current frame id
+ """
+ self.__updateIconAtFrame(frameId)
+
+ def __updateIconAtFrame(self, frameId):
+ """
+ Update the current stored QIcon
+
+ :param int frameId: Current frame id
+ """
+ if frameId in self.__cacheIcons:
+ icon = self.__cacheIcons[frameId]
+ else:
+ icon = qt.QIcon(self.__movie.currentPixmap())
+ self.__cacheIcons[frameId] = icon
+ self._setCurrentIcon(icon)
+
+ def _updateState(self):
+ """Update the movie play according to internal stat of the
+ AnimatedIcon."""
+ self.__movie.setPaused(not self.hasRegistredObjects())
+
+
+class MultiImageAnimatedIcon(AbstractAnimatedIcon):
+ """Store a looping QMovie to provide icons for each frames.
+ Provides an event with the new icon everytime the movie frame
+ is updated."""
+
+ def __init__(self, filename, parent=None):
+ """Constructor
+
+ :param str filename: An icon name to an animated format
+ :param qt.QObject parent: Parent of the QObject
+ :raises: ValueError when name is not known
+ """
+ AbstractAnimatedIcon.__init__(self, parent)
+
+ self.__frames = []
+ for i in range(100):
+ try:
+ pixmap = getQPixmap("animated/%s-%02d" % (filename, i))
+ except ValueError:
+ break
+ icon = qt.QIcon(pixmap)
+ self.__frames.append(icon)
+
+ if len(self.__frames) == 0:
+ raise ValueError("Animated icon '%s' do not exists" % filename)
+
+ self.__frameId = -1
+ self.__timer = qt.QTimer(self)
+ self.__timer.timeout.connect(self.__increaseFrame)
+ self.__updateIconAtFrame(0)
+
+ def __increaseFrame(self):
+ """Callback called every timer timeout to change the current frame of
+ the animation
+ """
+ frameId = (self.__frameId + 1) % len(self.__frames)
+ self.__updateIconAtFrame(frameId)
+
+ def __updateIconAtFrame(self, frameId):
+ """
+ Update the current stored QIcon
+
+ :param int frameId: Current frame id
+ """
+ self.__frameId = frameId
+ icon = self.__frames[frameId]
+ self._setCurrentIcon(icon)
+
+ def _updateState(self):
+ """Update the object to wake up or sleep it according to its use."""
+ if self.hasRegistredObjects():
+ if not self.__timer.isActive():
+ self.__timer.start(100)
+ else:
+ if self.__timer.isActive():
+ self.__timer.stop()
+
+
+class AnimatedIcon(MovieAnimatedIcon):
+ """Store a looping QMovie to provide icons for each frames.
+ Provides an event with the new icon everytime the movie frame
+ is updated.
+
+ It may not be available anymore for the silx release 0.6.
+
+ .. deprecated:: 0.5
+ Use :class:`MovieAnimatedIcon` instead.
+ """
+
+ @deprecated
+ def __init__(self, filename, parent=None):
+ MovieAnimatedIcon.__init__(self, filename, parent=parent)
+
+
+def getWaitIcon():
+ """Returns a cached version of the waiting AbstractAnimatedIcon.
+
+ :rtype: AbstractAnimatedIcon
+ """
+ return getAnimatedIcon("process-working")
+
+
+def getAnimatedIcon(name):
+ """Create an AbstractAnimatedIcon from a name.
+
+ Try to load a mng or a gif file, then try to load a multi-image animated
+ icon.
+
+ In Qt5 mng or gif are not used. It does not take care very well of the
+ transparency.
+
+ :param str name: Name of the icon, in one of the defined icons
+ in this module.
+ :return: Corresponding AbstractAnimatedIcon
+ :raises: ValueError when name is not known
+ """
+ key = name + "__anim"
+ if key not in _cached_icons:
+
+ qtMajorVersion = int(qt.qVersion().split(".")[0])
+ icon = None
+
+ # ignore mng and gif in Qt5
+ if qtMajorVersion != 5:
+ try:
+ icon = MovieAnimatedIcon(name)
+ except ValueError:
+ icon = None
+
+ if icon is None:
+ try:
+ icon = MultiImageAnimatedIcon(name)
+ except ValueError:
+ icon = None
+
+ if icon is None:
+ raise ValueError("Not an animated icon name: %s", name)
+
+ _cached_icons[key] = icon
+ else:
+ icon = _cached_icons[key]
+ return icon
+
+
+def getQIcon(name):
+ """Create a QIcon from its name.
+
+ :param str name: Name of the icon, in one of the defined icons
+ in this module.
+ :return: Corresponding QIcon
+ :raises: ValueError when name is not known
+ """
+ if name not in _cached_icons:
+ qfile = getQFile(name)
+ icon = qt.QIcon(qfile.fileName())
+ _cached_icons[name] = icon
+ else:
+ icon = _cached_icons[name]
+ return icon
+
+
+def getQPixmap(name):
+ """Create a QPixmap from its name.
+
+ :param str name: Name of the icon, in one of the defined icons
+ in this module.
+ :return: Corresponding QPixmap
+ :raises: ValueError when name is not known
+ """
+ qfile = getQFile(name)
+ return qt.QPixmap(qfile.fileName())
+
+
+def getQFile(name):
+ """Create a QFile from an icon name. Filename is found
+ according to supported Qt formats.
+
+ :param str name: Name of the icon, in one of the defined icons
+ in this module.
+ :return: Corresponding QFile
+ :rtype: qt.QFile
+ :raises: ValueError when name is not known
+ """
+ global _supported_formats
+ if _supported_formats is None:
+ _supported_formats = []
+ supported_formats = qt.supportedImageFormats()
+ order = ["mng", "gif", "svg", "png", "jpg"]
+ for format_ in order:
+ if format_ in supported_formats:
+ _supported_formats.append(format_)
+ if len(_supported_formats) == 0:
+ _logger.error("No format supported for icons")
+ else:
+ _logger.debug("Format %s supported", ", ".join(_supported_formats))
+
+ for format_ in _supported_formats:
+ format_ = str(format_)
+ filename = resource_filename('gui/icons/%s.%s' % (name, format_))
+ qfile = qt.QFile(filename)
+ if qfile.exists():
+ return qfile
+ raise ValueError('Not an icon name: %s' % name)
diff --git a/silx/gui/plot/AlphaSlider.py b/silx/gui/plot/AlphaSlider.py
new file mode 100644
index 0000000..ab2e5aa
--- /dev/null
+++ b/silx/gui/plot/AlphaSlider.py
@@ -0,0 +1,300 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines slider widgets interacting with the transparency
+of an image on a :class:`PlotWidget`
+
+Classes:
+--------
+
+- :class:`BaseAlphaSlider` (abstract class)
+- :class:`NamedImageAlphaSlider`
+- :class:`ActiveImageAlphaSlider`
+
+Example:
+--------
+
+This widget can, for instance, be added to a plot toolbar.
+
+.. code-block:: python
+
+ import numpy
+ from silx.gui import qt
+ from silx.gui.plot import PlotWidget
+ from silx.gui.plot.ImageAlphaSlider import NamedImageAlphaSlider
+
+ app = qt.QApplication([])
+ pw = PlotWidget()
+
+ img0 = numpy.arange(200*150).reshape((200, 150))
+ pw.addImage(img0, legend="my background", z=0, origin=(50, 50))
+
+ x, y = numpy.meshgrid(numpy.linspace(-10, 10, 200),
+ numpy.linspace(-10, 5, 150),
+ indexing="ij")
+ img1 = numpy.asarray(numpy.sin(x * y) / (x * y),
+ dtype='float32')
+
+ pw.addImage(img1, legend="my data", z=1,
+ replace=False)
+
+ alpha_slider = NamedImageAlphaSlider(parent=pw,
+ plot=pw,
+ legend="my data")
+ alpha_slider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", pw)
+ toolbar.addWidget(alpha_slider)
+ pw.addToolBar(toolbar)
+
+ pw.show()
+ app.exec_()
+
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/03/2017"
+
+import logging
+
+from silx.gui import qt
+
+_logger = logging.getLogger(__name__)
+
+
+class BaseAlphaSlider(qt.QSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of a plot primitive (image, scatter or curve).
+
+ Internally, the slider stores its state as an integer between
+ 0 and 255. This is the value emitted by the :attr:`valueChanged`
+ signal.
+
+ The method :meth:`getAlpha` returns the corresponding opacity/alpha
+ as a float between 0. and 1. (with a step of :math:`\frac{1}{255}`).
+
+ You must subclass this class and implement :meth:`getItem`.
+ """
+ sigAlphaChanged = qt.Signal(float)
+ """Emits the alpha value when the slider's value changes,
+ as a float between 0. and 1."""
+
+ def __init__(self, parent=None, plot=None):
+ """
+
+ :param parent: Parent QWidget
+ :param plot: Parent plot widget
+ """
+ assert plot is not None
+ super(BaseAlphaSlider, self).__init__(parent)
+
+ self.plot = plot
+
+ self.setRange(0, 255)
+
+ # if already connected to an item, use its alpha as initial value
+ if self.getItem() is None:
+ self.setValue(255)
+ self.setEnabled(False)
+ else:
+ alpha = self.getItem().getAlpha()
+ self.setValue(round(255*alpha))
+
+ self.valueChanged.connect(self._valueChanged)
+
+ def getItem(self):
+ """You must implement this class to define which item
+ to work on. It must return an item that inherits
+ :class:`silx.gui.plot.items.core.AlphaMixIn`.
+
+ :return: Item on which to operate, or None
+ :rtype: :class:`silx.plot.items.Item`
+ """
+ raise NotImplementedError(
+ "BaseAlphaSlider must be subclassed to " +
+ "implement getItem()")
+
+ def getAlpha(self):
+ """Get the opacity, as a float between 0. and 1.
+
+ :return: Alpha value in [0., 1.]
+ :rtype: float
+ """
+ return self.value() / 255.
+
+ def _valueChanged(self, value):
+ self._updateItem()
+ self.sigAlphaChanged.emit(value / 255.)
+
+ def _updateItem(self):
+ """Update the item's alpha channel.
+ """
+ item = self.getItem()
+ if item is not None:
+ item.setAlpha(self.getAlpha())
+
+
+class ActiveImageAlphaSlider(BaseAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of the **active image**.
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+
+ See documentation of :class:`BaseAlphaSlider`
+ """
+ def __init__(self, parent=None, plot=None):
+ """
+
+ :param parent: Parent QWidget
+ :param plot: Plot widget on which to operate
+ """
+ super(ActiveImageAlphaSlider, self).__init__(parent, plot)
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+
+ def getItem(self):
+ return self.plot.getActiveImage()
+
+ def _activeImageChanged(self, previous, new):
+ """Activate or deactivate slider depending on presence of a new
+ active image.
+ Apply transparency value to new active image.
+
+ :param previous: Legend of previous active image, or None
+ :param new: Legend of new active image, or None
+ """
+ if new is not None and not self.isEnabled():
+ self.setEnabled(True)
+ elif new is None and self.isEnabled():
+ self.setEnabled(False)
+
+ self._updateItem()
+
+
+class NamedItemAlphaSlider(BaseAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of an item (defined by its kind and legend).
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+ :param str kind: Kind of item whose transparency is to be
+ controlled: "scatter", "image" or "curve".
+ :param str legend: Legend of item whose transparency is to be
+ controlled.
+ """
+ def __init__(self, parent=None, plot=None,
+ kind=None, legend=None):
+ self._item_legend = legend
+ self._item_kind = kind
+
+ super(NamedItemAlphaSlider, self).__init__(parent, plot)
+
+ self._updateState()
+ plot.sigContentChanged.connect(self._onContentChanged)
+
+ def _onContentChanged(self, action, kind, legend):
+ if legend == self._item_legend and kind == self._item_kind:
+ if action == "add":
+ self.setEnabled(True)
+ elif action == "remove":
+ self.setEnabled(False)
+
+ def _updateState(self):
+ """Enable or disable widget based on item's availability."""
+ if self.getItem() is not None:
+ self.setEnabled(True)
+ else:
+ self.setEnabled(False)
+
+ def getItem(self):
+ """Return plot item currently associated to this widget (can be
+ a curve, an image, a scatter...)
+
+ :rtype: subclass of :class:`silx.gui.plot.items.Item`"""
+ if self._item_legend is None or self._item_kind is None:
+ return None
+ return self.plot._getItem(kind=self._item_kind,
+ legend=self._item_legend)
+
+ def setLegend(self, legend):
+ """Associate a different item (of the same kind) to the slider.
+
+ :param legend: New legend of item whose transparency is to be
+ controlled.
+ """
+ self._item_legend = legend
+ self._updateState()
+
+ def getLegend(self):
+ """Return legend of the item currently controlled by this slider.
+
+ :return: Image legend associated to the slider
+ """
+ return self._item_kind
+
+ def setItemKind(self, legend):
+ """Associate a different item (of the same kind) to the slider.
+
+ :param legend: New legend of item whose transparency is to be
+ controlled.
+ """
+ self._item_legend = legend
+ self._updateState()
+
+ def getItemKind(self):
+ """Return kind of the item currently controlled by this slider.
+
+ :return: Item kind ("image", "scatter"...)
+ :rtype: str on None
+ """
+ return self._item_kind
+
+
+class NamedImageAlphaSlider(NamedItemAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of an image (defined by its legend).
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+ :param str legend: Legend of image whose transparency is to be
+ controlled.
+ """
+ def __init__(self, parent=None, plot=None, legend=None):
+ NamedItemAlphaSlider.__init__(self, parent, plot,
+ kind="image", legend=legend)
+
+
+class NamedScatterAlphaSlider(NamedItemAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of a scatter (defined by its legend).
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+ :param str legend: Legend of scatter whose transparency is to be
+ controlled.
+ """
+ def __init__(self, parent=None, plot=None, legend=None):
+ NamedItemAlphaSlider.__init__(self, parent, plot,
+ kind="scatter", legend=legend)
diff --git a/silx/gui/plot/ColorBar.py b/silx/gui/plot/ColorBar.py
new file mode 100644
index 0000000..93e3c36
--- /dev/null
+++ b/silx/gui/plot/ColorBar.py
@@ -0,0 +1,790 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Module containing several widgets associated to a colormap.
+"""
+
+__authors__ = ["H. Payno", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "11/04/2017"
+
+
+import logging
+import numpy
+from ._utils import ticklayout
+from ._utils import clipColormapLogRange
+
+
+from .. import qt
+from silx.gui.plot import Colors
+
+_logger = logging.getLogger(__name__)
+
+
+class ColorBarWidget(qt.QWidget):
+ """Colorbar widget displaying a colormap
+
+ It uses a description of colormap as dict compatible with :class:`Plot`.
+
+ .. image:: img/linearColorbar.png
+ :width: 80px
+ :align: center
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> from silx.gui.plot import Plot2D
+ >>> from silx.gui.plot.ColorBar import ColorBarWidget
+
+ >>> plot = Plot2D() # Create a plot widget
+ >>> plot.show()
+
+ >>> colorbar = ColorBarWidget(plot=plot, legend='Colormap') # Associate the colorbar with it
+ >>> colorbar.show()
+
+ Initializer parameters:
+
+ :param parent: See :class:`QWidget`
+ :param plot: PlotWidget the colorbar is attached to (optional)
+ :param str legend: the label to set to the colormap
+ """
+
+ def __init__(self, parent=None, plot=None, legend=None):
+ super(ColorBarWidget, self).__init__(parent)
+ self._plot = None
+
+ self.__buildGUI()
+ self.setLegend(legend)
+ self.setPlot(plot)
+
+ def __buildGUI(self):
+ self.setLayout(qt.QHBoxLayout())
+
+ # create color scale widget
+ self._colorScale = ColorScaleBar(parent=self,
+ colormap=None)
+ self.layout().addWidget(self._colorScale)
+
+ # legend (is the right group)
+ self.legend = _VerticalLegend('', self)
+ self.layout().addWidget(self.legend)
+
+ self.layout().setSizeConstraint(qt.QLayout.SetMinAndMaxSize)
+ self.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Expanding)
+ self.layout().setContentsMargins(0, 0, 0, 0)
+
+ def getPlot(self):
+ """Returns the :class:`Plot` associated to this widget or None"""
+ return self._plot
+
+ def setPlot(self, plot):
+ """Associate a plot to the ColorBar
+
+ :param plot: the plot to associate with the colorbar. If None will remove
+ any connection with a previous plot.
+ """
+ # removing previous plot if any
+ if self._plot is not None:
+ self._plot.sigActiveImageChanged.disconnect(self._activeImageChanged)
+
+ # setting the new plot
+ self._plot = plot
+ if self._plot is not None:
+ self._plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ self._activeImageChanged(self._plot.getActiveImage(just_legend=True))
+
+ def getColormap(self):
+ """Return the colormap displayed in the colorbar as a dict.
+
+ It returns None if no colormap is set.
+ See :class:`silx.gui.plot.Plot` documentation for the description of the colormap
+ dict description.
+ """
+ return self._colormap.copy()
+
+ def setColormap(self, colormap):
+ """Set the colormap to be displayed.
+
+ :param dict colormap: The colormap to apply on the ColorBarWidget
+ """
+ self._colormap = colormap
+ if self._colormap is None:
+ return
+
+ if self._colormap['normalization'] not in ('log', 'linear'):
+ raise ValueError('Wrong normalization %s' % self._colormap['normalization'])
+
+ if self._colormap['normalization'] is 'log':
+ if self._colormap['vmin'] < 1. or self._colormap['vmax'] < 1.:
+ _logger.warning('Log colormap with bound <= 1: changing bounds.')
+ clipColormapLogRange(colormap)
+
+ self.getColorScaleBar().setColormap(self._colormap)
+
+ def setLegend(self, legend):
+ """Set the legend displayed along the colorbar
+
+ :param str legend: The label
+ """
+ if legend is None or legend == "":
+ self.legend.hide()
+ self.legend.setText("")
+ else:
+ assert(type(legend) is str)
+ self.legend.show()
+ self.legend.setText(legend)
+
+ def getLegend(self):
+ """
+ Returns the legend displayed along the colorbar
+
+ :return: return the legend displayed along the colorbar
+ :rtype: str
+ """
+ return self.legend.getText()
+
+ def _activeImageChanged(self, legend):
+ """Handle plot active curve changed"""
+ if legend is None: # No active image, display default colormap
+ self._syncWithDefaultColormap()
+ return
+
+ # Sync with active image
+ image = self._plot.getActiveImage().getData(copy=False)
+
+ # RGB(A) image, display default colormap
+ if image.ndim != 2:
+ self._syncWithDefaultColormap()
+ return
+
+ # data image, sync with image colormap
+ # do we need the copy here : used in the case we are changing
+ # vmin and vmax but should have already be done by the plot
+ cmap = self._plot.getActiveImage().getColormap().copy()
+ if cmap['autoscale']:
+ if cmap['normalization'] == 'log':
+ data = image[
+ numpy.logical_and(image > 0, numpy.isfinite(image))]
+ else:
+ data = image[numpy.isfinite(image)]
+ cmap['vmin'], cmap['vmax'] = data.min(), data.max()
+
+ self.setColormap(cmap)
+
+ def _defaultColormapChanged(self):
+ """Handle plot default colormap changed"""
+ if self._plot.getActiveImage() is None:
+ # No active image, take default colormap update into account
+ self._syncWithDefaultColormap()
+
+ def _syncWithDefaultColormap(self):
+ """Update colorbar according to plot default colormap"""
+ self.setColormap(self._plot.getDefaultColormap())
+
+ def getColorScaleBar(self):
+ """
+
+ :return: return the :class:`ColorScaleBar` used to display ColorScale
+ and ticks"""
+ return self._colorScale
+
+
+class _VerticalLegend(qt.QLabel):
+ """Display vertically the given text
+ """
+ def __init__(self, text, parent=None):
+ """
+
+ :param text: the legend
+ :param parent: the Qt parent if any
+ """
+ qt.QLabel.__init__(self, text, parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setContentsMargins(0, 0, 0, 0)
+
+ def paintEvent(self, event):
+ painter = qt.QPainter(self)
+ painter.setFont(self.font())
+
+ painter.translate(0, self.rect().height())
+ painter.rotate(270)
+ newRect = qt.QRect(0, 0, self.rect().height(), self.rect().width())
+
+ painter.drawText(newRect, qt.Qt.AlignHCenter, self.text())
+
+ fm = qt.QFontMetrics(self.font())
+ preferedHeight = fm.width(self.text())
+ preferedWidth = fm.height()
+ self.setFixedWidth(preferedWidth)
+ self.setMinimumHeight(preferedHeight)
+
+
+class ColorScaleBar(qt.QWidget):
+ """This class is making the composition of a :class:`_ColorScale` and a
+ :class:`_TickBar`.
+
+ It is the simplest widget displaying ticks and colormap gradient.
+
+ .. image:: img/colorScaleBar.png
+ :width: 150px
+ :align: center
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> colormap={'name':'gray',
+ ... 'normalization':'log',
+ ... 'vmin':1,
+ ... 'vmax':100000,
+ ... 'autoscale':False
+ ... }
+ >>> colorscale = ColorScaleBar(parent=None,
+ ... colormap=colormap )
+ >>> colorscale.show()
+
+ Initializer parameters :
+
+ :param colormap: the colormap to be displayed
+ :param parent: the Qt parent if any
+ :param displayTicksValues: display the ticks value or only the '-'
+ """
+
+ _TEXT_MARGIN = 5
+ """The tick bar need a margin to display all labels at the correct place.
+ So the ColorScale should have the same margin in order for both to fit"""
+
+ _MIN_LIM_SCI_FORM = -1000
+ """Used for the min and max label to know when we should display it under
+ the scientific form"""
+
+ _MAX_LIM_SCI_FORM = 1000
+ """Used for the min and max label to know when we should display it under
+ the scientific form"""
+
+ def __init__(self, parent=None, colormap=None, displayTicksValues=True):
+ super(ColorScaleBar, self).__init__(parent)
+
+ self.minVal = None
+ """Value set to the _minLabel"""
+ self.maxVal = None
+ """Value set to the _maxLabel"""
+
+ self.setLayout(qt.QGridLayout())
+
+ # create the left side group (ColorScale)
+ self.colorScale = _ColorScale(colormap=colormap,
+ parent=self,
+ margin=ColorScaleBar._TEXT_MARGIN)
+
+ self.tickbar = _TickBar(vmin=colormap['vmin'] if colormap else 0.0,
+ vmax=colormap['vmax'] if colormap else 1.0,
+ norm=colormap['normalization'] if colormap else 'linear',
+ parent=self,
+ displayValues=displayTicksValues,
+ margin=ColorScaleBar._TEXT_MARGIN)
+
+ self.layout().addWidget(self.tickbar, 1, 0)
+ self.layout().addWidget(self.colorScale, 1, 1)
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.layout().setSpacing(0)
+
+ # max label
+ self._maxLabel = qt.QLabel(str(1.0), parent=self)
+ self._maxLabel.setAlignment(qt.Qt.AlignHCenter)
+ self._maxLabel.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum)
+ self.layout().addWidget(self._maxLabel, 0, 1)
+
+ # min label
+ self._minLabel = qt.QLabel(str(0.0), parent=self)
+ self._minLabel.setAlignment(qt.Qt.AlignHCenter)
+ self._minLabel.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum)
+ self.layout().addWidget(self._minLabel, 2, 1)
+
+ def getTickBar(self):
+ """
+
+ :return: the instanciation of the :class:`_TickBar`
+ """
+ return self.tickbar
+
+ def getColorScale(self):
+ """
+
+ :return: the instanciation of the :class:`_ColorScale`
+ """
+ return self.colorScale
+
+ def setColormap(self, colormap):
+ """Set the new colormap to be displayed
+
+ :param dict colormap: the colormap to set
+ """
+ if colormap is not None:
+ self.colorScale.setColormap(colormap)
+
+ self.tickbar.update(vmin=colormap['vmin'],
+ vmax=colormap['vmax'],
+ norm=colormap['normalization'])
+
+ self._setMinMaxLabels(colormap['vmin'], colormap['vmax'])
+
+ def setMinMaxVisible(self, val=True):
+ """Change visibility of the min label and the max label
+
+ :param val: if True, set the labels visible, otherwise set it not visible
+ """
+ self._maxLabel.show() if val is True else self._maxLabel.hide()
+ self._minLabel.show() if val is True else self._minLabel.hide()
+
+ def _updateMinMax(self):
+ """Update the min and max label if we are in the case of the
+ configuration 'minMaxValueOnly'"""
+ if self._minLabel is not None and self._maxLabel is not None:
+ if self.minVal is not None:
+ if ColorScaleBar._MIN_LIM_SCI_FORM <= self.minVal <= ColorScaleBar._MAX_LIM_SCI_FORM:
+ self._minLabel.setText(str(self.minVal))
+ else:
+ self._minLabel.setText("{0:.0e}".format(self.minVal))
+ if self.maxVal is not None:
+ if ColorScaleBar._MIN_LIM_SCI_FORM <= self.maxVal <= ColorScaleBar._MAX_LIM_SCI_FORM:
+ self._maxLabel.setText(str(self.maxVal))
+ else:
+ self._maxLabel.setText("{0:.0e}".format(self.maxVal))
+
+ def _setMinMaxLabels(self, minVal, maxVal):
+ """Change the value of the min and max labels to be displayed.
+
+ :param minVal: the minimal value of the TickBar (not str)
+ :param maxVal: the maximal value of the TickBar (not str)
+ """
+ # bad hack to try to display has much information as possible
+ self.minVal = minVal
+ self.maxVal = maxVal
+ self._updateMinMax()
+
+ def resizeEvent(self, event):
+ qt.QWidget.resizeEvent(self, event)
+ self._updateMinMax()
+
+
+class _ColorScale(qt.QWidget):
+ """Widget displaying the colormap colorScale.
+
+ Show matching value between the gradient color (from the colormap) at mouse
+ position and value.
+
+ .. image:: img/colorScale.png
+ :width: 20px
+ :align: center
+
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> colormap={'name':'viridis',
+ ... 'normalization':'log',
+ ... 'vmin':1,
+ ... 'vmax':100000,
+ ... 'autoscale':False
+ ... }
+ >>> colorscale = ColorScale(parent=None,
+ ... colormap=colormap)
+ >>> colorscale.show()
+
+ Initializer parameters :
+
+ :param colormap: the colormap to be displayed
+ :param parent: the Qt parent if any
+ :param int margin: the top and left margin to apply.
+
+ .. warning:: Value drawing will be
+ done at the center of ticks. So if no margin is done your values
+ drawing might not be fully done for extrems values.
+ """
+
+ _NB_CONTROL_POINTS = 256
+
+ def __init__(self, colormap, parent=None, margin=5):
+ qt.QWidget.__init__(self, parent)
+ self.colormap = None
+ self.setColormap(colormap)
+
+ self.setLayout(qt.QVBoxLayout())
+ self.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
+ # needed to get the mouse event without waiting for button click
+ self.setMouseTracking(True)
+ self.setMargin(margin)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ def setColormap(self, colormap):
+ """Set the new colormap to be displayed
+
+ :param dict colormap: the colormap to set
+ """
+ if colormap is None:
+ return
+
+ if colormap['normalization'] not in ('log', 'linear'):
+ raise ValueError("Unrecognized normalization, should be 'linear' or 'log'")
+
+ if colormap['normalization'] is 'log':
+ if not (colormap['vmin'] > 0 and colormap['vmax'] > 0):
+ raise ValueError('vmin and vmax should be positives')
+ self.colormap = colormap
+ self._computeColorPoints()
+
+ def _computeColorPoints(self):
+ """Compute the color points for the gradient
+ """
+ if self.colormap is None:
+ return
+
+ vmin = self.colormap['vmin']
+ vmax = self.colormap['vmax']
+ steps = (vmax - vmin)/float(_ColorScale._NB_CONTROL_POINTS)
+ self.ctrPoints = numpy.arange(vmin, vmax, steps)
+ self.colorsCtrPts = Colors.applyColormapToData(self.ctrPoints,
+ name=self.colormap['name'],
+ normalization='linear',
+ autoscale=self.colormap['autoscale'],
+ vmin=vmin,
+ vmax=vmax)
+
+ def paintEvent(self, event):
+ """"""
+ qt.QWidget.paintEvent(self, event)
+ if self.colormap is None:
+ return
+
+ vmin = self.colormap['vmin']
+ vmax = self.colormap['vmax']
+
+ painter = qt.QPainter(self)
+ gradient = qt.QLinearGradient(0, 0, 0, self.rect().height() - 2*self.margin)
+ for iPt, pt in enumerate(self.ctrPoints):
+ colormapPosition = 1 - (pt-vmin) / (vmax-vmin)
+ assert(colormapPosition >= 0.0)
+ assert(colormapPosition <= 1.0)
+ gradient.setColorAt(colormapPosition, qt.QColor(*(self.colorsCtrPts[iPt])))
+
+ painter.setBrush(gradient)
+ painter.drawRect(
+ qt.QRect(0, self.margin, self.width(), self.height() - 2.*self.margin))
+
+ def mouseMoveEvent(self, event):
+ """"""
+ self.setToolTip(str(self.getValueFromRelativePosition(self._getRelativePosition(event.y()))))
+ super(_ColorScale, self).mouseMoveEvent(event)
+
+ def _getRelativePosition(self, yPixel):
+ """yPixel : pixel position into _ColorScale widget reference
+ """
+ # widgets are bottom-top referencial but we display in top-bottom referential
+ return 1 - float(yPixel)/float(self.height() - 2*self.margin)
+
+ def getValueFromRelativePosition(self, value):
+ """Return the value in the colorMap from a relative position in the
+ ColorScaleBar (y)
+
+ :param value: float value in [0, 1]
+ :return: the value in [colormap['vmin'], colormap['vmax']]
+ """
+ value = max(0.0, value)
+ value = min(value, 1.0)
+ vmin = self.colormap['vmin']
+ vmax = self.colormap['vmax']
+ if self.colormap['normalization'] is 'linear':
+ return vmin + (vmax - vmin) * value
+ elif self.colormap['normalization'] is 'log':
+ rpos = (numpy.log10(vmax) - numpy.log10(vmin)) * value + numpy.log10(vmin)
+ return numpy.power(10., rpos)
+ else:
+ err = "normalization type (%s) is not managed by the _ColorScale Widget" % self.colormap['normalization']
+ raise ValueError(err)
+
+ def setMargin(self, margin):
+ """Define the margin to fit with a TickBar object.
+ This is needed since we can only paint on the viewport of the widget.
+ Didn't work with a simple setContentsMargins
+
+ :param int margin: the margin to apply on the top and bottom.
+ """
+ self.margin = margin
+
+
+class _TickBar(qt.QWidget):
+ """Bar grouping the ticks displayed
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> bar = TickBar(1, 1000, norm='log', parent=None, displayValues=True)
+ >>> bar.show()
+
+ .. image:: img/tickbar.png
+ :width: 40px
+ :align: center
+
+ :param int vmin: smaller value of the range of values
+ :param int vmax: higher value of the range of values
+ :param str norm: normalization type to be displayed. Valid values are
+ 'linear' and 'log'
+ :param parent: the Qt parent if any
+ :param bool displayValues: if True display the values close to the tick,
+ Otherwise only signal it by '-'
+ :param int nticks: the number of tick we want to display. Should be an
+ unsigned int ot None. If None, let the Tick bar find the optimal
+ number of ticks from the tick density.
+ :param int margin: margin to set on the top and bottom
+ """
+ _WIDTH_DISP_VAL = 45
+ """widget width when displayed with ticks labels"""
+ _WIDTH_NO_DISP_VAL = 10
+ """widget width when displayed without ticks labels"""
+ _FONT_SIZE = 10
+ """font size for ticks labels"""
+ _LINE_WIDTH = 10
+ """width of the line to mark a tick"""
+
+ DEFAULT_TICK_DENSITY = 0.015
+
+ def __init__(self, vmin, vmax, norm, parent=None, displayValues=True,
+ nticks=None, margin=5):
+ super(_TickBar, self).__init__(parent)
+ self._forcedDisplayType = None
+ self.ticksDensity = _TickBar.DEFAULT_TICK_DENSITY
+
+ self._vmin = vmin
+ self._vmax = vmax
+ # TODO : should be grouped into a global function, called by all
+ # logScale displayer to make sure we have the same behavior everywhere
+ if self._vmin < 1. or self._vmax < 1.:
+ _logger.warning(
+ 'Log colormap with bound <= 1: changing bounds.')
+ self._vmin, self._vmax = 1., 10.
+
+ self._norm = norm
+ self.displayValues = displayValues
+ self.setTicksNumber(nticks)
+ self.setMargin(margin)
+
+ self.setLayout(qt.QVBoxLayout())
+ self.setMargin(margin)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ self._resetWidth()
+
+ def setTicksValuesVisible(self, val):
+ self.displayValues = val
+ self._resetWidth()
+
+ def _resetWidth(self):
+ self.width = _TickBar._WIDTH_DISP_VAL if self.displayValues else _TickBar._WIDTH_NO_DISP_VAL
+ self.setFixedWidth(self.width)
+
+ def update(self, vmin, vmax, norm):
+ self._vmin = vmin
+ self._vmax = vmax
+ self._norm = norm
+ self.computeTicks()
+ qt.QWidget.update(self)
+
+ def setMargin(self, margin):
+ """Define the margin to fit with a _ColorScale object.
+ This is needed since we can only paint on the viewport of the widget
+
+ :param int margin: the margin to apply on the top and bottom.
+ """
+ self.margin = margin
+
+ def setTicksNumber(self, nticks):
+ """Set the number of ticks to display.
+
+ :param nticks: the number of tick to be display. Should be an
+ unsigned int ot None. If None, let the :class:`_TickBar` find the
+ optimal number of ticks from the tick density.
+ """
+ self._nticks = nticks
+ self.ticks = None
+ self.computeTicks()
+ qt.QWidget.update(self)
+
+ def setTicksDensity(self, density):
+ """If you let :class:`_TickBar` deal with the number of ticks
+ (nticks=None) then you can specify a ticks density to be displayed.
+ """
+ if density < 0.0:
+ raise ValueError('Density should be a positive value')
+ self.ticksDensity = density
+
+ def computeTicks(self):
+ """This function compute ticks values labels. It is called at each
+ update and each resize event.
+ Deal only with linear and log scale.
+ """
+ nticks = self._nticks
+ if nticks is None:
+ nticks = self._getOptimalNbTicks()
+
+ if self._norm == 'log':
+ self._computeTicksLog(nticks)
+ elif self._norm == 'linear':
+ self._computeTicksLin(nticks)
+ else:
+ err = 'TickBar - Wrong normalization %s' % self._norm
+ raise ValueError(err)
+ # update the form
+ font = qt.QFont()
+ font.setPixelSize(_TickBar._FONT_SIZE)
+
+ self.form = self._getFormat(font)
+
+ def _computeTicksLog(self, nticks):
+ logMin = numpy.log10(self._vmin)
+ logMax = numpy.log10(self._vmax)
+ lowBound, highBound, spacing, self._nfrac = ticklayout.niceNumbersForLog10(logMin,
+ logMax,
+ nticks)
+ self.ticks = numpy.power(10., numpy.arange(lowBound, highBound, spacing))
+ if spacing == 1:
+ self.subTicks = ticklayout.computeLogSubTicks(ticks=self.ticks,
+ lowBound=numpy.power(10., lowBound),
+ highBound=numpy.power(10., highBound))
+ else:
+ self.subTicks = []
+
+ def resizeEvent(self, event):
+ qt.QWidget.resizeEvent(self, event)
+ self.computeTicks()
+
+ def _computeTicksLin(self, nticks):
+ _min, _max, _spacing, self._nfrac = ticklayout.niceNumbers(self._vmin,
+ self._vmax,
+ nticks)
+
+ self.ticks = numpy.arange(_min, _max, _spacing)
+ self.subTicks = []
+
+ def _getOptimalNbTicks(self):
+ return max(2, int(round(self.ticksDensity * self.rect().height())))
+
+ def paintEvent(self, event):
+ painter = qt.QPainter(self)
+ font = painter.font()
+ font.setPixelSize(_TickBar._FONT_SIZE)
+ painter.setFont(font)
+
+ # paint ticks
+ if self.ticks is not None:
+ for val in self.ticks:
+ self._paintTick(val, painter, majorTick=True)
+
+ # paint subticks
+ for val in self.subTicks:
+ self._paintTick(val, painter, majorTick=False)
+
+ qt.QWidget.paintEvent(self, event)
+
+ def _getRelativePosition(self, val):
+ """Return the relative position of val according to min and max value
+ """
+ if self._norm == 'linear':
+ return 1 - (val - self._vmin) / (self._vmax - self._vmin)
+ elif self._norm == 'log':
+ return 1 - (numpy.log10(val) - numpy.log10(self._vmin))/(numpy.log10(self._vmax) - numpy.log(self._vmin))
+ else:
+ raise ValueError('Norm is not recognized')
+
+ def _paintTick(self, val, painter, majorTick=True):
+ """
+
+ :param bool majorTick: if False will never draw text and will set a line
+ with a smaller width
+ """
+ fm = qt.QFontMetrics(painter.font())
+ viewportHeight = self.rect().height() - self.margin * 2
+ relativePos = self._getRelativePosition(val)
+ height = viewportHeight * relativePos
+ height += self.margin
+ lineWidth = _TickBar._LINE_WIDTH
+ if majorTick is False:
+ lineWidth /= 2
+
+ painter.drawLine(qt.QLine(self.width - lineWidth,
+ height,
+ self.width,
+ height))
+
+ if self.displayValues and majorTick is True:
+ painter.drawText(qt.QPoint(0.0, height + (fm.height() / 2)),
+ self.form.format(val))
+
+ def setDisplayType(self, disType):
+ """Set the type of display we want to set for ticks labels
+
+ :param str disType: The type of display we want to set. disType values
+ can be :
+
+ - 'std' for standard, meaning only a formatting on the number of
+ digits is done
+ - 'e' for scientific display
+ - None to let the _TickBar guess the best display for this kind of data.
+ """
+ if disType not in (None, 'std', 'e'):
+ raise ValueError("display type not recognized, value should be in (None, 'std', 'e'")
+ self._forcedDisplayType = disType
+
+ def _getStandardFormat(self):
+ return "{0:.%sf}" % self._nfrac
+
+ def _getFormat(self, font):
+ if self._forcedDisplayType is None:
+ return self._guessType(font)
+ elif self._forcedDisplayType is 'std':
+ return self._getStandardFormat()
+ elif self._forcedDisplayType is 'e':
+ return self._getScientificForm()
+ else:
+ err = 'Forced type for display %s is not recognized' % self._forcedDisplayType
+ raise ValueError(err)
+
+ def _getScientificForm(self):
+ return "{0:.0e}"
+
+ def _guessType(self, font):
+ """Try fo find the better format to display the tick's labels
+
+ :param QFont font: the font we want want to use durint the painting
+ """
+ assert(type(self._vmin) == type(self._vmax))
+ form = self._getStandardFormat()
+
+ fm = qt.QFontMetrics(font)
+ width = 0
+ for tick in self.ticks:
+ width = max(fm.width(form.format(tick)), width)
+
+ # if the length of the string are too long we are mooving to scientific
+ # display
+ if width > _TickBar._WIDTH_DISP_VAL - _TickBar._LINE_WIDTH:
+ return self._getScientificForm()
+ else:
+ return form
diff --git a/silx/gui/plot/ColormapDialog.py b/silx/gui/plot/ColormapDialog.py
new file mode 100644
index 0000000..ad1425c
--- /dev/null
+++ b/silx/gui/plot/ColormapDialog.py
@@ -0,0 +1,506 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A QDialog widget to set-up the colormap.
+
+It uses a description of colormaps as dict compatible with :class:`Plot`.
+
+To run the following sample code, a QApplication must be initialized.
+
+Create the colormap dialog and set the colormap description and data range:
+
+>>> from silx.gui.plot.ColormapDialog import ColormapDialog
+
+>>> dialog = ColormapDialog()
+
+>>> dialog.setColormap(name='red', normalization='log',
+... autoscale=False, vmin=1., vmax=2.)
+>>> dialog.setDataRange(1., 100.) # This scale the width of the plot area
+>>> dialog.show()
+
+Get the colormap description (compatible with :class:`Plot`) from the dialog:
+
+>>> cmap = dialog.getColormap()
+>>> cmap['name']
+'red'
+
+It is also possible to display an histogram of the image in the dialog.
+This updates the data range with the range of the bins.
+
+>>> import numpy
+>>> image = numpy.random.normal(size=512 * 512).reshape(512, -1)
+>>> hist, bin_edges = numpy.histogram(image, bins=10)
+>>> dialog.setHistogram(hist, bin_edges)
+
+The updates of the colormap description are also available through the signal:
+:attr:`ColormapDialog.sigColormapChanged`.
+""" # noqa
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "29/03/2016"
+
+
+import logging
+
+import numpy
+
+from .. import qt
+from . import PlotWidget
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _FloatEdit(qt.QLineEdit):
+ """Field to edit a float value.
+
+ :param parent: See :class:`QLineEdit`
+ :param float value: The value to set the QLineEdit to.
+ """
+ def __init__(self, parent=None, value=None):
+ qt.QLineEdit.__init__(self, parent)
+ self.setValidator(qt.QDoubleValidator())
+ self.setAlignment(qt.Qt.AlignRight)
+ if value is not None:
+ self.setValue(value)
+
+ def value(self):
+ """Return the QLineEdit current value as a float."""
+ return float(self.text())
+
+ def setValue(self, value):
+ """Set the current value of the LineEdit
+
+ :param float value: The value to set the QLineEdit to.
+ """
+ self.setText('%g' % value)
+
+
+class ColormapDialog(qt.QDialog):
+ """A QDialog widget to set the colormap.
+
+ :param parent: See :class:`QDialog`
+ :param str title: The QDialog title
+ """
+
+ sigColormapChanged = qt.Signal(dict)
+ """Signal triggered when the colormap is changed.
+
+ It provides a dict describing the colormap to the slot.
+ This dict can be used with :class:`Plot`.
+ """
+
+ def __init__(self, parent=None, title="Colormap Dialog"):
+ qt.QDialog.__init__(self, parent)
+ self.setWindowTitle(title)
+
+ self._histogramData = None
+ self._dataRange = None
+ self._minMaxWasEdited = False
+
+ self._colormapList = (
+ 'gray', 'reversed gray',
+ 'temperature', 'red', 'green', 'blue', 'jet',
+ 'viridis', 'magma', 'inferno', 'plasma')
+
+ # Make the GUI
+ vLayout = qt.QVBoxLayout(self)
+
+ formWidget = qt.QWidget()
+ vLayout.addWidget(formWidget)
+ formLayout = qt.QFormLayout(formWidget)
+ formLayout.setContentsMargins(10, 10, 10, 10)
+ formLayout.setSpacing(0)
+
+ # Colormap row
+ self._comboBoxColormap = qt.QComboBox()
+ for cmap in self._colormapList:
+ # Capitalize first letters
+ cmap = ' '.join(w[0].upper() + w[1:] for w in cmap.split())
+ self._comboBoxColormap.addItem(cmap)
+ self._comboBoxColormap.activated[int].connect(self._notify)
+ formLayout.addRow('Colormap:', self._comboBoxColormap)
+
+ # Normalization row
+ self._normButtonLinear = qt.QRadioButton('Linear')
+ self._normButtonLinear.setChecked(True)
+ self._normButtonLog = qt.QRadioButton('Log')
+
+ normButtonGroup = qt.QButtonGroup(self)
+ normButtonGroup.setExclusive(True)
+ normButtonGroup.addButton(self._normButtonLinear)
+ normButtonGroup.addButton(self._normButtonLog)
+ normButtonGroup.buttonClicked[int].connect(self._notify)
+
+ normLayout = qt.QHBoxLayout()
+ normLayout.setContentsMargins(0, 0, 0, 0)
+ normLayout.setSpacing(10)
+ normLayout.addWidget(self._normButtonLinear)
+ normLayout.addWidget(self._normButtonLog)
+
+ formLayout.addRow('Normalization:', normLayout)
+
+ # Range row
+ self._rangeAutoscaleButton = qt.QCheckBox('Autoscale')
+ self._rangeAutoscaleButton.setChecked(True)
+ self._rangeAutoscaleButton.toggled.connect(self._autoscaleToggled)
+ self._rangeAutoscaleButton.clicked.connect(self._notify)
+ formLayout.addRow('Range:', self._rangeAutoscaleButton)
+
+ # Min row
+ self._minValue = _FloatEdit(value=1.)
+ self._minValue.setEnabled(False)
+ self._minValue.textEdited.connect(self._minMaxTextEdited)
+ self._minValue.editingFinished.connect(self._minEditingFinished)
+ formLayout.addRow('\tMin:', self._minValue)
+
+ # Max row
+ self._maxValue = _FloatEdit(value=10.)
+ self._maxValue.setEnabled(False)
+ self._maxValue.textEdited.connect(self._minMaxTextEdited)
+ self._maxValue.editingFinished.connect(self._maxEditingFinished)
+ formLayout.addRow('\tMax:', self._maxValue)
+
+ # Add plot for histogram
+ self._plotInit()
+ vLayout.addWidget(self._plot)
+
+ # Close button
+ buttonsWidget = qt.QWidget()
+ vLayout.addWidget(buttonsWidget)
+
+ buttonsLayout = qt.QHBoxLayout(buttonsWidget)
+
+ okButton = qt.QPushButton('OK')
+ okButton.clicked.connect(self.accept)
+ buttonsLayout.addWidget(okButton)
+
+ cancelButton = qt.QPushButton('Cancel')
+ cancelButton.clicked.connect(self.reject)
+ buttonsLayout.addWidget(cancelButton)
+
+ # colormap window can not be resized
+ self.setFixedSize(vLayout.minimumSize())
+
+ # Set the colormap to default values
+ self.setColormap(name='gray', normalization='linear',
+ autoscale=True, vmin=1., vmax=10.)
+
+ def _plotInit(self):
+ """Init the plot to display the range and the values"""
+ self._plot = PlotWidget()
+ self._plot.setDataMargins(yMinMargin=0.125, yMaxMargin=0.125)
+ self._plot.setGraphXLabel("Data Values")
+ self._plot.setGraphYLabel("")
+ self._plot.setInteractiveMode('select', zoomOnWheel=False)
+ self._plot.setActiveCurveHandling(False)
+ self._plot.setMinimumSize(qt.QSize(250, 200))
+ self._plot.sigPlotSignal.connect(self._plotSlot)
+ self._plot.hide()
+
+ self._plotUpdate()
+
+ def _plotUpdate(self, updateMarkers=True):
+ """Update the plot content
+
+ :param bool updateMarkers: True to update markers, False otherwith
+ """
+ dataRange = self.getDataRange()
+
+ if dataRange is None:
+ if self._plot.isVisibleTo(self):
+ self._plot.setVisible(False)
+ self.setFixedSize(self.layout().minimumSize())
+ return
+
+ if not self._plot.isVisibleTo(self):
+ self._plot.setVisible(True)
+ self.setFixedSize(self.layout().minimumSize())
+
+ dataMin, dataMax = dataRange
+ marge = (abs(dataMax) + abs(dataMin)) / 6.0
+ minmd = dataMin - marge
+ maxpd = dataMax + marge
+
+ start, end = self._minValue.value(), self._maxValue.value()
+
+ if start <= end:
+ x = [minmd, start, end, maxpd]
+ y = [0, 0, 1, 1]
+
+ else:
+ x = [minmd, end, start, maxpd]
+ y = [1, 1, 0, 0]
+
+ # Display the colormap on the side
+ # colormap = {'name': self.getColormap()['name'],
+ # 'normalization': self.getColormap()['normalization'],
+ # 'autoscale': True, 'vmin': 1., 'vmax': 256.}
+ # self._plot.addImage((1 + numpy.arange(256)).reshape(256, -1),
+ # xScale=(minmd - marge, marge),
+ # yScale=(1., 2./256.),
+ # legend='colormap',
+ # colormap=colormap)
+
+ self._plot.addCurve(x, y,
+ legend="ConstrainedCurve",
+ color='black',
+ symbol='o',
+ linestyle='-',
+ resetzoom=False)
+
+ draggable = not self._rangeAutoscaleButton.isChecked()
+
+ if updateMarkers:
+ self._plot.addXMarker(
+ self._minValue.value(),
+ legend='Min',
+ text='Min',
+ draggable=draggable,
+ color='blue',
+ constraint=self._plotMinMarkerConstraint)
+
+ self._plot.addXMarker(
+ self._maxValue.value(),
+ legend='Max',
+ text='Max',
+ draggable=draggable,
+ color='blue',
+ constraint=self._plotMaxMarkerConstraint)
+
+ self._plot.resetZoom()
+
+ def _plotMinMarkerConstraint(self, x, y):
+ """Constraint of the min marker"""
+ return min(x, self._maxValue.value()), y
+
+ def _plotMaxMarkerConstraint(self, x, y):
+ """Constraint of the max marker"""
+ return max(x, self._minValue.value()), y
+
+ def _plotSlot(self, event):
+ """Handle events from the plot"""
+ if event['event'] in ('markerMoving', 'markerMoved'):
+ value = float(str(event['xdata']))
+ if event['label'] == 'Min':
+ self._minValue.setValue(value)
+ elif event['label'] == 'Max':
+ self._maxValue.setValue(value)
+
+ # This will recreate the markers while interacting...
+ # It might break if marker interaction is changed
+ if event['event'] == 'markerMoved':
+ self._notify()
+ else:
+ self._plotUpdate(updateMarkers=False)
+
+ def getHistogram(self):
+ """Returns the counts and bin edges of the displayed histogram.
+
+ :return: (hist, bin_edges)
+ :rtype: 2-tuple of numpy arrays"""
+ if self._histogramData is None:
+ return None
+ else:
+ bins, counts = self._histogramData
+ return numpy.array(bins, copy=True), numpy.array(counts, copy=True)
+
+ def setHistogram(self, hist=None, bin_edges=None):
+ """Set the histogram to display.
+
+ This update the data range with the bounds of the bins.
+ See :meth:`setDataRange`.
+
+ :param hist: array-like of counts or None to hide histogram
+ :param bin_edges: array-like of bins edges or None to hide histogram
+ """
+ if hist is None or bin_edges is None:
+ self._histogramData = None
+ self._plot.remove(legend='Histogram', kind='curve')
+ self.setDataRange() # Remove data range
+
+ else:
+ hist = numpy.array(hist, copy=True)
+ bin_edges = numpy.array(bin_edges, copy=True)
+ self._histogramData = hist, bin_edges
+
+ # For now, draw the histogram as a curve
+ # using bin centers and normalised counts
+ bins_center = 0.5 * (bin_edges[:-1] + bin_edges[1:])
+ norm_hist = hist / max(hist)
+ self._plot.addCurve(bins_center, norm_hist,
+ legend="Histogram",
+ color='gray',
+ symbol='',
+ linestyle='-',
+ fill=True)
+
+ # Update the data range
+ self.setDataRange(bin_edges[0], bin_edges[-1])
+
+ def getDataRange(self):
+ """Returns the data range used for the histogram area.
+
+ :return: (dataMin, dataMax) or None if no data range is set
+ :rtype: 2-tuple of float
+ """
+ return self._dataRange
+
+ def setDataRange(self, min_=None, max_=None):
+ """Set the range of data to use for the range of the histogram area.
+
+ :param float min_: The min of the data or None to disable range.
+ :param float max_: The max of the data or None to disable range.
+ """
+ if min_ is None or max_ is None:
+ self._dataRange = None
+ self._plotUpdate()
+
+ else:
+ min_, max_ = float(min_), float(max_)
+ assert min_ <= max_
+ self._dataRange = min_, max_
+ if self._rangeAutoscaleButton.isChecked():
+ self._minValue.setValue(min_)
+ self._maxValue.setValue(max_)
+ self._notify()
+ else:
+ self._plotUpdate()
+
+ def getColormap(self):
+ """Return the colormap description as a dict.
+
+ See :class:`Plot` for documentation on the colormap dict.
+ """
+ isNormLinear = self._normButtonLinear.isChecked()
+ colormap = {
+ 'name': str(self._comboBoxColormap.currentText()).lower(),
+ 'normalization': 'linear' if isNormLinear else 'log',
+ 'autoscale': self._rangeAutoscaleButton.isChecked(),
+ 'vmin': self._minValue.value(),
+ 'vmax': self._maxValue.value()}
+ return colormap
+
+ def setColormap(self, name=None, normalization=None,
+ autoscale=None, vmin=None, vmax=None, colors=None):
+ """Set the colormap description
+
+ If some arguments are not provided, the current values are used.
+
+ :param str name: The name of the colormap
+ :param str normalization: 'linear' or 'log'
+ :param bool autoscale: Toggle colormap range autoscale
+ :param float vmin: The min value, ignored if autoscale is True
+ :param float vmax: The max value, ignored if autoscale is True
+ """
+ if name is not None:
+ assert name in self._colormapList
+ index = self._colormapList.index(name)
+ self._comboBoxColormap.setCurrentIndex(index)
+
+ if normalization is not None:
+ assert normalization in ('linear', 'log')
+ self._normButtonLinear.setChecked(normalization == 'linear')
+ self._normButtonLog.setChecked(normalization == 'log')
+
+ if vmin is not None:
+ self._minValue.setValue(vmin)
+
+ if vmax is not None:
+ self._maxValue.setValue(vmax)
+
+ if autoscale is not None:
+ self._rangeAutoscaleButton.setChecked(autoscale)
+ if autoscale:
+ dataRange = self.getDataRange()
+ if dataRange is not None:
+ self._minValue.setValue(dataRange[0])
+ self._maxValue.setValue(dataRange[1])
+
+ # Do it once for all the changes
+ self._notify()
+
+ def _notify(self, *args, **kwargs):
+ """Emit the signal for colormap change"""
+ self._plotUpdate()
+ self.sigColormapChanged.emit(self.getColormap())
+
+ def _autoscaleToggled(self, checked):
+ """Handle autoscale changes by enabling/disabling min/max fields"""
+ self._minValue.setEnabled(not checked)
+ self._maxValue.setEnabled(not checked)
+ if checked:
+ dataRange = self.getDataRange()
+ if dataRange is not None:
+ self._minValue.setValue(dataRange[0])
+ self._maxValue.setValue(dataRange[1])
+
+ def _minMaxTextEdited(self, text):
+ """Handle _minValue and _maxValue textEdited signal"""
+ self._minMaxWasEdited = True
+
+ def _minEditingFinished(self):
+ """Handle _minValue editingFinished signal
+
+ Together with :meth:`_minMaxTextEdited`, this avoids to notify
+ colormap change when the min and max value where not edited.
+ """
+ if self._minMaxWasEdited:
+ self._minMaxWasEdited = False
+
+ # Fix start value
+ if self._minValue.value() > self._maxValue.value():
+ self._minValue.setValue(self._maxValue.value())
+ self._notify()
+
+ def _maxEditingFinished(self):
+ """Handle _maxValue editingFinished signal
+
+ Together with :meth:`_minMaxTextEdited`, this avoids to notify
+ colormap change when the min and max value where not edited.
+ """
+ if self._minMaxWasEdited:
+ self._minMaxWasEdited = False
+
+ # Fix end value
+ if self._minValue.value() > self._maxValue.value():
+ self._maxValue.setValue(self._minValue.value())
+ self._notify()
+
+ def keyPressEvent(self, event):
+ """Override key handling.
+
+ It disables leaving the dialog when editing a text field.
+ """
+ if event.key() == qt.Qt.Key_Enter and (self._minValue.hasFocus() or
+ self._maxValue.hasFocus()):
+ # Bypass QDialog keyPressEvent
+ # To avoid leaving the dialog when pressing enter on a text field
+ super(qt.QDialog, self).keyPressEvent(event)
+ else:
+ # Use QDialog keyPressEvent
+ super(ColormapDialog, self).keyPressEvent(event)
diff --git a/silx/gui/plot/Colors.py b/silx/gui/plot/Colors.py
new file mode 100644
index 0000000..7a3cd97
--- /dev/null
+++ b/silx/gui/plot/Colors.py
@@ -0,0 +1,359 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Color conversion function, color dictionary and colormap tools."""
+
+__authors__ = ["V.A. Sole", "T. VINCENT"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+
+import logging
+
+import numpy
+
+import matplotlib
+import matplotlib.colors
+import matplotlib.cm
+
+from . import MPLColormap
+
+
+_logger = logging.getLogger(__name__)
+
+
+COLORDICT = {}
+"""Dictionary of common colors."""
+
+COLORDICT['b'] = COLORDICT['blue'] = '#0000ff'
+COLORDICT['r'] = COLORDICT['red'] = '#ff0000'
+COLORDICT['g'] = COLORDICT['green'] = '#00ff00'
+COLORDICT['k'] = COLORDICT['black'] = '#000000'
+COLORDICT['w'] = COLORDICT['white'] = '#ffffff'
+COLORDICT['pink'] = '#ff66ff'
+COLORDICT['brown'] = '#a52a2a'
+COLORDICT['orange'] = '#ff9900'
+COLORDICT['violet'] = '#6600ff'
+COLORDICT['gray'] = COLORDICT['grey'] = '#a0a0a4'
+# COLORDICT['darkGray'] = COLORDICT['darkGrey'] = '#808080'
+# COLORDICT['lightGray'] = COLORDICT['lightGrey'] = '#c0c0c0'
+COLORDICT['y'] = COLORDICT['yellow'] = '#ffff00'
+COLORDICT['m'] = COLORDICT['magenta'] = '#ff00ff'
+COLORDICT['c'] = COLORDICT['cyan'] = '#00ffff'
+COLORDICT['darkBlue'] = '#000080'
+COLORDICT['darkRed'] = '#800000'
+COLORDICT['darkGreen'] = '#008000'
+COLORDICT['darkBrown'] = '#660000'
+COLORDICT['darkCyan'] = '#008080'
+COLORDICT['darkYellow'] = '#808000'
+COLORDICT['darkMagenta'] = '#800080'
+
+
+def rgba(color, colorDict=None):
+ """Convert color code '#RRGGBB' and '#RRGGBBAA' to (R, G, B, A)
+
+ It also convert RGB(A) values from uint8 to float in [0, 1] and
+ accept a QColor as color argument.
+
+ :param str color: The color to convert
+ :param dict colorDict: A dictionary of color name conversion to color code
+ :returns: RGBA colors as floats in [0., 1.]
+ :rtype: tuple
+ """
+ if colorDict is None:
+ colorDict = COLORDICT
+
+ if hasattr(color, 'getRgbF'): # QColor support
+ color = color.getRgbF()
+
+ values = numpy.asarray(color).ravel()
+
+ if values.dtype.kind in 'iuf': # integer or float
+ # Color is an array
+ assert len(values) in (3, 4)
+
+ # Convert from integers in [0, 255] to float in [0, 1]
+ if values.dtype.kind in 'iu':
+ values = values / 255.
+
+ # Clip to [0, 1]
+ values[values < 0.] = 0.
+ values[values > 1.] = 1.
+
+ if len(values) == 3:
+ return values[0], values[1], values[2], 1.
+ else:
+ return tuple(values)
+
+ # We assume color is a string
+ if not color.startswith('#'):
+ color = colorDict[color]
+
+ assert len(color) in (7, 9) and color[0] == '#'
+ r = int(color[1:3], 16) / 255.
+ g = int(color[3:5], 16) / 255.
+ b = int(color[5:7], 16) / 255.
+ a = int(color[7:9], 16) / 255. if len(color) == 9 else 1.
+ return r, g, b, a
+
+
+_COLORMAP_CURSOR_COLORS = {
+ 'gray': 'pink',
+ 'reversed gray': 'pink',
+ 'temperature': 'pink',
+ 'red': 'green',
+ 'green': 'pink',
+ 'blue': 'yellow',
+ 'jet': 'pink',
+ 'viridis': 'pink',
+ 'magma': 'green',
+ 'inferno': 'green',
+ 'plasma': 'green',
+}
+
+
+def cursorColorForColormap(colormapName):
+ """Get a color suitable for overlay over a colormap.
+
+ :param str colormapName: The name of the colormap.
+ :return: Name of the color.
+ :rtype: str
+ """
+ return _COLORMAP_CURSOR_COLORS.get(colormapName, 'black')
+
+
+_CMAPS = {} # Store additional colormaps
+
+
+def getMPLColormap(name):
+ """Returns matplotlib colormap corresponding to given name
+
+ :param str name: The name of the colormap
+ :return: The corresponding colormap
+ :rtype: matplolib.colors.Colormap
+ """
+ if not _CMAPS: # Lazy initialization of own colormaps
+ cdict = {'red': ((0.0, 0.0, 0.0),
+ (1.0, 1.0, 1.0)),
+ 'green': ((0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0)),
+ 'blue': ((0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0))}
+ _CMAPS['red'] = matplotlib.colors.LinearSegmentedColormap(
+ 'red', cdict, 256)
+
+ cdict = {'red': ((0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0)),
+ 'green': ((0.0, 0.0, 0.0),
+ (1.0, 1.0, 1.0)),
+ 'blue': ((0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0))}
+ _CMAPS['green'] = matplotlib.colors.LinearSegmentedColormap(
+ 'green', cdict, 256)
+
+ cdict = {'red': ((0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0)),
+ 'green': ((0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0)),
+ 'blue': ((0.0, 0.0, 0.0),
+ (1.0, 1.0, 1.0))}
+ _CMAPS['blue'] = matplotlib.colors.LinearSegmentedColormap(
+ 'blue', cdict, 256)
+
+ # Temperature as defined in spslut
+ cdict = {'red': ((0.0, 0.0, 0.0),
+ (0.5, 0.0, 0.0),
+ (0.75, 1.0, 1.0),
+ (1.0, 1.0, 1.0)),
+ 'green': ((0.0, 0.0, 0.0),
+ (0.25, 1.0, 1.0),
+ (0.75, 1.0, 1.0),
+ (1.0, 0.0, 0.0)),
+ 'blue': ((0.0, 1.0, 1.0),
+ (0.25, 1.0, 1.0),
+ (0.5, 0.0, 0.0),
+ (1.0, 0.0, 0.0))}
+ # but limited to 256 colors for a faster display (of the colorbar)
+ _CMAPS['temperature'] = \
+ matplotlib.colors.LinearSegmentedColormap(
+ 'temperature', cdict, 256)
+
+ # reversed gray
+ cdict = {'red': ((0.0, 1.0, 1.0),
+ (1.0, 0.0, 0.0)),
+ 'green': ((0.0, 1.0, 1.0),
+ (1.0, 0.0, 0.0)),
+ 'blue': ((0.0, 1.0, 1.0),
+ (1.0, 0.0, 0.0))}
+
+ _CMAPS['reversed gray'] = \
+ matplotlib.colors.LinearSegmentedColormap(
+ 'yerg', cdict, 256)
+
+ if name in _CMAPS:
+ return _CMAPS[name]
+ elif hasattr(MPLColormap, name): # viridis and sister colormaps
+ return getattr(MPLColormap, name)
+ else:
+ # matplotlib built-in
+ return matplotlib.cm.get_cmap(name)
+
+
+def getMPLScalarMappable(colormap, data=None):
+ """Returns matplotlib ScalarMappable corresponding to colormap
+
+ :param dict colormap: The colormap to convert
+ :param numpy.ndarray data:
+ The data on which the colormap is applied.
+ If provided, it is used to compute autoscale.
+ :return: matplotlib object corresponding to colormap
+ :rtype: matplotlib.cm.ScalarMappable
+ """
+ assert colormap is not None
+
+ if colormap['name'] is not None:
+ cmap = getMPLColormap(colormap['name'])
+
+ else: # No name, use custom colors
+ if 'colors' not in colormap:
+ raise ValueError(
+ 'addImage: colormap no name nor list of colors.')
+ colors = numpy.array(colormap['colors'], copy=True)
+ assert len(colors.shape) == 2
+ assert colors.shape[-1] in (3, 4)
+ if colors.dtype == numpy.uint8:
+ # Convert to float in [0., 1.]
+ colors = colors.astype(numpy.float32) / 255.
+ cmap = matplotlib.colors.ListedColormap(colors)
+
+ if colormap['normalization'].startswith('log'):
+ vmin, vmax = None, None
+ if not colormap['autoscale']:
+ if colormap['vmin'] > 0.:
+ vmin = colormap['vmin']
+ if colormap['vmax'] > 0.:
+ vmax = colormap['vmax']
+
+ if vmin is None or vmax is None:
+ _logger.warning('Log colormap with negative bounds, ' +
+ 'changing bounds to positive ones.')
+ elif vmin > vmax:
+ _logger.warning('Colormap bounds are inverted.')
+ vmin, vmax = vmax, vmin
+
+ # Set unset/negative bounds to positive bounds
+ if (vmin is None or vmax is None) and data is not None:
+ finiteData = data[numpy.isfinite(data)]
+ posData = finiteData[finiteData > 0]
+ if vmax is None:
+ # 1. as an ultimate fallback
+ vmax = posData.max() if posData.size > 0 else 1.
+ if vmin is None:
+ vmin = posData.min() if posData.size > 0 else vmax
+ if vmin > vmax:
+ vmin = vmax
+
+ norm = matplotlib.colors.LogNorm(vmin, vmax)
+
+ else: # Linear normalization
+ if colormap['autoscale']:
+ if data is None:
+ vmin, vmax = None, None
+ else:
+ finiteData = data[numpy.isfinite(data)]
+ vmin = finiteData.min()
+ vmax = finiteData.max()
+ else:
+ vmin = colormap['vmin']
+ vmax = colormap['vmax']
+ if vmin > vmax:
+ _logger.warning('Colormap bounds are inverted.')
+ vmin, vmax = vmax, vmin
+
+ norm = matplotlib.colors.Normalize(vmin, vmax)
+
+ return matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
+
+
+def applyColormapToData(data,
+ name='gray',
+ normalization='linear',
+ autoscale=True,
+ vmin=0.,
+ vmax=1.,
+ colors=None):
+ """Apply a colormap to the data and returns the RGBA image
+
+ This supports data of any dimensions (not only of dimension 2).
+ The returned array will have one more dimension (with 4 entries)
+ than the input data to store the RGBA channels
+ corresponding to each bin in the array.
+
+ :param numpy.ndarray data: The data to convert.
+ :param str name: Name of the colormap (default: 'gray').
+ :param str normalization: Colormap mapping: 'linear' or 'log'.
+ :param bool autoscale: Whether to use data min/max (True, default)
+ or [vmin, vmax] range (False).
+ :param float vmin: The minimum value of the range to use if
+ 'autoscale' is False.
+ :param float vmax: The maximum value of the range to use if
+ 'autoscale' is False.
+ :param numpy.ndarray colors: Only used if name is None.
+ Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
+ :return: The computed RGBA image
+ :rtype: numpy.ndarray of uint8
+ """
+ # Debian 7 specific support
+ # No transparent colormap with matplotlib < 1.2.0
+ # Add support for transparent colormap for uint8 data with
+ # colormap with 256 colors, linear norm, [0, 255] range
+ if matplotlib.__version__ < '1.2.0':
+ if name is None and colors is not None:
+ colors = numpy.array(colors, copy=False)
+ if (colors.shape[-1] == 4 and
+ not numpy.all(numpy.equal(colors[3], 255))):
+ # This is a transparent colormap
+ if (colors.shape == (256, 4) and
+ normalization == 'linear' and
+ not autoscale and
+ vmin == 0 and vmax == 255 and
+ data.dtype == numpy.uint8):
+ # Supported case, convert data to RGBA
+ return colors[data.reshape(-1)].reshape(
+ data.shape + (4,))
+ else:
+ _logger.warning(
+ 'matplotlib %s does not support transparent '
+ 'colormap.', matplotlib.__version__)
+
+ colormap = dict(name=name,
+ normalization=normalization,
+ autoscale=autoscale,
+ vmin=vmin,
+ vmax=vmax,
+ colors=colors)
+ scalarMappable = getMPLScalarMappable(colormap, data)
+ rgbaImage = scalarMappable.to_rgba(data, bytes=True)
+
+ return rgbaImage
diff --git a/silx/gui/plot/CurvesROIWidget.py b/silx/gui/plot/CurvesROIWidget.py
new file mode 100644
index 0000000..13c3de0
--- /dev/null
+++ b/silx/gui/plot/CurvesROIWidget.py
@@ -0,0 +1,975 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget to handle regions of interest (ROI) on curves displayed in a PlotWindow.
+
+This widget is meant to work with :class:`PlotWindow`.
+
+ROI are defined by :
+
+- A name (`ROI` column)
+- A type. The type is the label of the x axis.
+ This can be used to apply or not some ROI to a curve and do some post processing.
+- The x coordinate of the left limit (`from` column)
+- The x coordinate of the right limit (`to` column)
+- Raw counts: integral of the curve between the
+ min ROI point and the max ROI point to the y = 0 line
+
+ .. image:: img/rawCounts.png
+
+- Net counts: the integral of the curve between the
+ min ROI point and the max ROI point to [ROI min point, ROI max point] segment
+
+ .. image:: img/netCounts.png
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/04/2017"
+
+from collections import OrderedDict
+
+import logging
+import os
+import sys
+
+import numpy
+
+from silx.io import dictdump
+from .. import icons, qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CurvesROIWidget(qt.QWidget):
+ """Widget displaying a table of ROI information.
+
+ :param parent: See :class:`QWidget`
+ :param str name: The title of this widget
+ """
+
+ sigROIWidgetSignal = qt.Signal(object)
+ """Signal of ROIs modifications.
+
+ Modification information if given as a dict with an 'event' key
+ providing the type of events.
+
+ Type of events:
+
+ - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict'
+
+ - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader',
+ 'rowheader'
+ """
+
+ def __init__(self, parent=None, name=None):
+ super(CurvesROIWidget, self).__init__(parent)
+ if name is not None:
+ self.setWindowTitle(name)
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+ ##############
+ self.headerLabel = qt.QLabel(self)
+ self.headerLabel.setAlignment(qt.Qt.AlignHCenter)
+ self.setHeader()
+ layout.addWidget(self.headerLabel)
+ ##############
+ self.roiTable = ROITable(self)
+ rheight = self.roiTable.horizontalHeader().sizeHint().height()
+ self.roiTable.setMinimumHeight(4 * rheight)
+ self.fillFromROIDict = self.roiTable.fillFromROIDict
+ self.getROIListAndDict = self.roiTable.getROIListAndDict
+ layout.addWidget(self.roiTable)
+ self._roiFileDir = qt.QDir.home().absolutePath()
+ #################
+
+ hbox = qt.QWidget(self)
+ hboxlayout = qt.QHBoxLayout(hbox)
+ hboxlayout.setContentsMargins(0, 0, 0, 0)
+ hboxlayout.setSpacing(0)
+
+ hboxlayout.addStretch(0)
+
+ self.addButton = qt.QPushButton(hbox)
+ self.addButton.setText("Add ROI")
+ self.addButton.setToolTip('Create a new ROI')
+ self.delButton = qt.QPushButton(hbox)
+ self.delButton.setText("Delete ROI")
+ self.addButton.setToolTip('Remove the selected ROI')
+ self.resetButton = qt.QPushButton(hbox)
+ self.resetButton.setText("Reset")
+ self.addButton.setToolTip('Clear all created ROIs. We only let the default ROI')
+
+ hboxlayout.addWidget(self.addButton)
+ hboxlayout.addWidget(self.delButton)
+ hboxlayout.addWidget(self.resetButton)
+
+ hboxlayout.addStretch(0)
+
+ self.loadButton = qt.QPushButton(hbox)
+ self.loadButton.setText("Load")
+ self.loadButton.setToolTip('Load ROIs from a .ini file')
+ self.saveButton = qt.QPushButton(hbox)
+ self.saveButton.setText("Save")
+ self.loadButton.setToolTip('Save ROIs to a .ini file')
+ hboxlayout.addWidget(self.loadButton)
+ hboxlayout.addWidget(self.saveButton)
+ layout.setStretchFactor(self.headerLabel, 0)
+ layout.setStretchFactor(self.roiTable, 1)
+ layout.setStretchFactor(hbox, 0)
+
+ layout.addWidget(hbox)
+
+ self.addButton.clicked.connect(self._add)
+ self.delButton.clicked.connect(self._del)
+ self.resetButton.clicked.connect(self._reset)
+
+ self.loadButton.clicked.connect(self._load)
+ self.saveButton.clicked.connect(self._save)
+ self.roiTable.sigROITableSignal.connect(self._forward)
+
+ @property
+ def roiFileDir(self):
+ """The directory from which to load/save ROI from/to files."""
+ if not os.path.isdir(self._roiFileDir):
+ self._roiFileDir = qt.QDir.home().absolutePath()
+ return self._roiFileDir
+
+ @roiFileDir.setter
+ def roiFileDir(self, roiFileDir):
+ self._roiFileDir = str(roiFileDir)
+
+ def setRois(self, roidict, order=None):
+ """Set the ROIs by providing a dictionary of ROI information.
+
+ The dictionary keys are the ROI names.
+ Each value is a sub-dictionary of ROI info with the following fields:
+
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+
+
+ :param roidict: Dictionary of ROIs
+ :param str order: Field used for ordering the ROIs.
+ One of "from", "to", "type".
+ None (default) for no ordering, or same order as specified
+ in parameter ``roidict`` if provided as an OrderedDict.
+ """
+ if order is None or order.lower() == "none":
+ roilist = list(roidict.keys())
+ else:
+ assert order in ["from", "to", "type"]
+ roilist = sorted(roidict.keys(),
+ key=lambda roi_name: roidict[roi_name].get(order))
+
+ return self.roiTable.fillFromROIDict(roilist, roidict)
+
+ def getRois(self, order=None):
+ """Return the currently defined ROIs, as an ordered dict.
+
+ The dictionary keys are the ROI names.
+ Each value is a sub-dictionary of ROI info with the following fields:
+
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+ :param order: Field used for ordering the ROIs.
+ One of "from", "to", "type", "netcounts", "rawcounts".
+ None (default) to get the same order as displayed in the widget.
+ :return: Ordered dictionary of ROI information
+ """
+ roilist, roidict = self.roiTable.getROIListAndDict()
+ if order is None or order.lower() == "none":
+ ordered_roilist = roilist
+ else:
+ assert order in ["from", "to", "type", "netcounts", "rawcounts"]
+ ordered_roilist = sorted(roidict.keys(),
+ key=lambda roi_name: roidict[roi_name].get(order))
+
+ return OrderedDict([(name, roidict[name]) for name in ordered_roilist])
+
+ def _add(self):
+ """Add button clicked handler"""
+ ddict = {}
+ ddict['event'] = "AddROI"
+ roilist, roidict = self.roiTable.getROIListAndDict()
+ ddict['roilist'] = roilist
+ ddict['roidict'] = roidict
+ self.sigROIWidgetSignal.emit(ddict)
+
+ def _del(self):
+ """Delete button clicked handler"""
+ row = self.roiTable.currentRow()
+ if row >= 0:
+ index = self.roiTable.labels.index('Type')
+ text = str(self.roiTable.item(row, index).text())
+ if text.upper() != 'DEFAULT':
+ index = self.roiTable.labels.index('ROI')
+ key = str(self.roiTable.item(row, index).text())
+ else:
+ # This is to prevent deleting ICR ROI, that is
+ # usually initialized as "Default" type.
+ return
+ roilist, roidict = self.roiTable.getROIListAndDict()
+ row = roilist.index(key)
+ del roilist[row]
+ del roidict[key]
+ if len(roilist) > 0:
+ currentroi = roilist[0]
+ else:
+ currentroi = None
+
+ self.roiTable.fillFromROIDict(roilist=roilist,
+ roidict=roidict,
+ currentroi=currentroi)
+ ddict = {}
+ ddict['event'] = "DelROI"
+ ddict['roilist'] = roilist
+ ddict['roidict'] = roidict
+ self.sigROIWidgetSignal.emit(ddict)
+
+ def _forward(self, ddict):
+ """Broadcast events from ROITable signal"""
+ self.sigROIWidgetSignal.emit(ddict)
+
+ def _reset(self):
+ """Reset button clicked handler"""
+ ddict = {}
+ ddict['event'] = "ResetROI"
+ roilist0, roidict0 = self.roiTable.getROIListAndDict()
+ index = 0
+ for key in roilist0:
+ if roidict0[key]['type'].upper() == 'DEFAULT':
+ index = roilist0.index(key)
+ break
+ roilist = []
+ roidict = {}
+ if len(roilist0):
+ roilist.append(roilist0[index])
+ roidict[roilist[0]] = {}
+ roidict[roilist[0]].update(roidict0[roilist[0]])
+ self.roiTable.fillFromROIDict(roilist=roilist, roidict=roidict)
+ ddict['roilist'] = roilist
+ ddict['roidict'] = roidict
+ self.sigROIWidgetSignal.emit(ddict)
+
+ def _load(self):
+ """Load button clicked handler"""
+ dialog = qt.QFileDialog(self)
+ dialog.setNameFilters(
+ ['INI File *.ini', 'JSON File *.json', 'All *.*'])
+ dialog.setFileMode(qt.QFileDialog.ExistingFile)
+ dialog.setDirectory(self.roiFileDir)
+ if not dialog.exec_():
+ dialog.close()
+ return
+
+ # pyflakes bug http://bugs.debian.org/cgi-bin/bugreport.cgi?bug=666494
+ outputFile = dialog.selectedFiles()[0]
+ dialog.close()
+
+ self.roiFileDir = os.path.dirname(outputFile)
+ self.load(outputFile)
+
+ def load(self, filename):
+ """Load ROI widget information from a file storing a dict of ROI.
+
+ :param str filename: The file from which to load ROI
+ """
+ rois = dictdump.load(filename)
+ currentROI = None
+ if self.roiTable.rowCount():
+ item = self.roiTable.item(self.roiTable.currentRow(), 0)
+ if item is not None:
+ currentROI = str(item.text())
+
+ # Remove rawcounts and netcounts from ROIs
+ for roi in rois['ROI']['roidict'].values():
+ roi.pop('rawcounts', None)
+ roi.pop('netcounts', None)
+
+ self.roiTable.fillFromROIDict(roilist=rois['ROI']['roilist'],
+ roidict=rois['ROI']['roidict'],
+ currentroi=currentROI)
+
+ roilist, roidict = self.roiTable.getROIListAndDict()
+ event = {'event': 'LoadROI', 'roilist': roilist, 'roidict': roidict}
+ self.sigROIWidgetSignal.emit(event)
+
+ def _save(self):
+ """Save button clicked handler"""
+ dialog = qt.QFileDialog(self)
+ dialog.setNameFilters(['INI File *.ini', 'JSON File *.json'])
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+ dialog.setDirectory(self.roiFileDir)
+ if not dialog.exec_():
+ dialog.close()
+ return
+
+ outputFile = dialog.selectedFiles()[0]
+ extension = '.' + dialog.selectedNameFilter().split('.')[-1]
+ dialog.close()
+
+ if not outputFile.endswith(extension):
+ outputFile += extension
+
+ if os.path.exists(outputFile):
+ try:
+ os.remove(outputFile)
+ except IOError:
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Input Output Error: %s" % (sys.exc_info()[1]))
+ msg.exec_()
+ return
+ self.roiFileDir = os.path.dirname(outputFile)
+ self.save(outputFile)
+
+ def save(self, filename):
+ """Save current ROIs of the widget as a dict of ROI to a file.
+
+ :param str filename: The file to which to save the ROIs
+ """
+ roilist, roidict = self.roiTable.getROIListAndDict()
+ datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}}
+ dictdump.dump(datadict, filename)
+
+ def setHeader(self, text='ROIs'):
+ """Set the header text of this widget"""
+ self.headerLabel.setText("<b>%s<\b>" % text)
+
+
+class ROITable(qt.QTableWidget):
+ """Table widget displaying ROI information.
+
+ See :class:`QTableWidget` for constructor arguments.
+ """
+
+ sigROITableSignal = qt.Signal(object)
+ """Signal of ROI table modifications.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(ROITable, self).__init__(*args, **kwargs)
+ self.setRowCount(1)
+ self.labels = 'ROI', 'Type', 'From', 'To', 'Raw Counts', 'Net Counts'
+ self.setColumnCount(len(self.labels))
+ self.setSortingEnabled(False)
+
+ for index, label in enumerate(self.labels):
+ item = self.horizontalHeaderItem(index)
+ if item is None:
+ item = qt.QTableWidgetItem(label,
+ qt.QTableWidgetItem.Type)
+ item.setText(label)
+ self.setHorizontalHeaderItem(index, item)
+
+ self.roidict = {}
+ self.roilist = []
+
+ self.building = False
+ self.fillFromROIDict(roilist=self.roilist, roidict=self.roidict)
+
+ self.cellClicked[(int, int)].connect(self._cellClickedSlot)
+ self.cellChanged[(int, int)].connect(self._cellChangedSlot)
+ verticalHeader = self.verticalHeader()
+ verticalHeader.sectionClicked[int].connect(self._rowChangedSlot)
+
+ self.__setTooltip()
+
+ def __setTooltip(self):
+ assert(self.labels[0] == 'ROI')
+ self.horizontalHeaderItem(0).setToolTip('Region of interest identifier')
+ assert(self.labels[1] == 'Type')
+ self.horizontalHeaderItem(1).setToolTip('Type of the ROI')
+ assert(self.labels[2] == 'From')
+ self.horizontalHeaderItem(2).setToolTip('X-value of the min point')
+ assert(self.labels[3] == 'To')
+ self.horizontalHeaderItem(3).setToolTip('X-value of the max point')
+ assert(self.labels[4] == 'Raw Counts')
+ self.horizontalHeaderItem(4).setToolTip('Estimation of the integral \
+ between y=0 and the selected curve')
+ assert(self.labels[5] == 'Net Counts')
+ self.horizontalHeaderItem(5).setToolTip('Estimation of the integral \
+ between the segment [maxPt, minPt] and the selected curve')
+
+ def fillFromROIDict(self, roilist=(), roidict=None, currentroi=None):
+ """Set the ROIs by providing a list of ROI names and a dictionary
+ of ROI information for each ROI.
+
+ The ROI names must match an existing dictionary key.
+ The name list is used to provide an order for the ROIs.
+
+ The dictionary's values are sub-dictionaries containing 3
+ mandatory fields:
+
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+
+ :param roilist: List of ROI names (keys of roidict)
+ :type roilist: List
+ :param dict roidict: Dict of ROI information
+ :param currentroi: Name of the selected ROI or None (no selection)
+ """
+ if roidict is None:
+ roidict = {}
+
+ self.building = True
+ line0 = 0
+ self.roilist = []
+ self.roidict = {}
+ for key in roilist:
+ if key in roidict.keys():
+ roi = roidict[key]
+ self.roilist.append(key)
+ self.roidict[key] = {}
+ self.roidict[key].update(roi)
+ line0 = line0 + 1
+ nlines = self.rowCount()
+ if (line0 > nlines):
+ self.setRowCount(line0)
+ line = line0 - 1
+ self.roidict[key]['line'] = line
+ ROI = key
+ roitype = "%s" % roi['type']
+ fromdata = "%6g" % (roi['from'])
+ todata = "%6g" % (roi['to'])
+ if 'rawcounts' in roi:
+ rawcounts = "%6g" % (roi['rawcounts'])
+ else:
+ rawcounts = " ?????? "
+ if 'netcounts' in roi:
+ netcounts = "%6g" % (roi['netcounts'])
+ else:
+ netcounts = " ?????? "
+ fields = [ROI, roitype, fromdata, todata, rawcounts, netcounts]
+ col = 0
+ for field in fields:
+ key2 = self.item(line, col)
+ if key2 is None:
+ key2 = qt.QTableWidgetItem(field,
+ qt.QTableWidgetItem.Type)
+ self.setItem(line, col, key2)
+ else:
+ key2.setText(field)
+ if (ROI.upper() == 'ICR') or (ROI.upper() == 'DEFAULT'):
+ key2.setFlags(qt.Qt.ItemIsSelectable |
+ qt.Qt.ItemIsEnabled)
+ else:
+ if col in [0, 2, 3]:
+ key2.setFlags(qt.Qt.ItemIsSelectable |
+ qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsEditable)
+ else:
+ key2.setFlags(qt.Qt.ItemIsSelectable |
+ qt.Qt.ItemIsEnabled)
+ col = col + 1
+ self.setRowCount(line0)
+ i = 0
+ for _label in self.labels:
+ self.resizeColumnToContents(i)
+ i = i + 1
+ self.sortByColumn(2, qt.Qt.AscendingOrder)
+ for i in range(len(self.roilist)):
+ key = str(self.item(i, 0).text())
+ self.roilist[i] = key
+ self.roidict[key]['line'] = i
+ if len(self.roilist) == 1:
+ self.selectRow(0)
+ else:
+ if currentroi in self.roidict.keys():
+ self.selectRow(self.roidict[currentroi]['line'])
+ _logger.debug("Qt4 ensureCellVisible to be implemented")
+ self.building = False
+
+ def getROIListAndDict(self):
+ """Return the currently defined ROIs, as a 2-tuple
+ ``(roiList, roiDict)``
+
+ ``roiList`` is a list of ROI names.
+ ``roiDict`` is a dictionary of ROI info.
+
+ The ROI names must match an existing dictionary key.
+ The name list is used to provide an order for the ROIs.
+
+ The dictionary's values are sub-dictionaries containing 3
+ fields:
+
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+
+
+ :return: ordered dict as a tuple of (list of ROI names, dict of info)
+ """
+ return self.roilist, self.roidict
+
+ def _cellClickedSlot(self, *var, **kw):
+ # selection changed event, get the current selection
+ row = self.currentRow()
+ col = self.currentColumn()
+ if row >= 0 and row < len(self.roilist):
+ item = self.item(row, 0)
+ text = '' if item is None else str(item.text())
+ self.roilist[row] = text
+ self._emitSelectionChangedSignal(row, col)
+
+ def _rowChangedSlot(self, row):
+ self._emitSelectionChangedSignal(row, 0)
+
+ def _cellChangedSlot(self, row, col):
+ _logger.debug("_cellChangedSlot(%d, %d)", row, col)
+ if self.building:
+ return
+ if col == 0:
+ self.nameSlot(row, col)
+ else:
+ self._valueChanged(row, col)
+
+ def _valueChanged(self, row, col):
+ if col not in [2, 3]:
+ return
+ item = self.item(row, col)
+ if item is None:
+ return
+ text = str(item.text())
+ try:
+ value = float(text)
+ except:
+ return
+ if row >= len(self.roilist):
+ _logger.debug("deleting???")
+ return
+ item = self.item(row, 0)
+ if item is None:
+ text = ""
+ else:
+ text = str(item.text())
+ if not len(text):
+ return
+ if col == 2:
+ self.roidict[text]['from'] = value
+ elif col == 3:
+ self.roidict[text]['to'] = value
+ self._emitSelectionChangedSignal(row, col)
+
+ def nameSlot(self, row, col):
+ if col != 0:
+ return
+ if row >= len(self.roilist):
+ _logger.debug("deleting???")
+ return
+ item = self.item(row, col)
+ if item is None:
+ text = ""
+ else:
+ text = str(item.text())
+ if len(text) and (text not in self.roilist):
+ old = self.roilist[row]
+ self.roilist[row] = text
+ self.roidict[text] = {}
+ self.roidict[text].update(self.roidict[old])
+ del self.roidict[old]
+ self._emitSelectionChangedSignal(row, col)
+
+ def _emitSelectionChangedSignal(self, row, col):
+ ddict = {}
+ ddict['event'] = "selectionChanged"
+ ddict['row'] = row
+ ddict['col'] = col
+ ddict['roi'] = self.roidict[self.roilist[row]]
+ ddict['key'] = self.roilist[row]
+ ddict['colheader'] = self.labels[col]
+ ddict['rowheader'] = "%d" % row
+ self.sigROITableSignal.emit(ddict)
+
+
+class CurvesROIDockWidget(qt.QDockWidget):
+ """QDockWidget with a :class:`CurvesROIWidget` connected to a PlotWindow.
+
+ It makes the link between the :class:`CurvesROIWidget` and the PlotWindow.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: :class:`.PlotWindow` instance on which to operate
+ :param name: See :class:`QDockWidget`
+ """
+ sigROISignal = qt.Signal(object)
+
+ def __init__(self, parent=None, plot=None, name=None):
+ super(CurvesROIDockWidget, self).__init__(name, parent)
+
+ assert plot is not None
+ self.plot = plot
+
+ self.currentROI = None
+ self._middleROIMarkerFlag = False
+
+ self._isConnected = False # True if connected to plot signals
+ self._isInit = False
+
+ self.roiWidget = CurvesROIWidget(self, name)
+ """Main widget of type :class:`CurvesROIWidget`"""
+
+ # convenience methods to offer a simpler API allowing to ignore
+ # the details of the underlying implementation
+ self.calculateROIs = self.calculateRois
+ self.setRois = self.roiWidget.setRois
+ self.getRois = self.roiWidget.getRois
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.setWidget(self.roiWidget)
+
+ self.visibilityChanged.connect(self._visibilityChangedHandler)
+
+ def toggleViewAction(self):
+ """Returns a checkable action that shows or closes this widget.
+
+ See :class:`QMainWindow`.
+ """
+ action = super(CurvesROIDockWidget, self).toggleViewAction()
+ action.setIcon(icons.getQIcon('plot-roi'))
+ return action
+
+ def _visibilityChangedHandler(self, visible):
+ """Handle widget's visibilty updates.
+
+ It is connected to plot signals only when visible.
+ """
+ if visible:
+ if not self._isInit:
+ # Deferred ROI widget init finalization
+ self._isInit = True
+ self.roiWidget.sigROIWidgetSignal.connect(self._roiSignal)
+ # initialize with the ICR
+ self._roiSignal({'event': "AddROI"})
+
+ if not self._isConnected:
+ self.plot.sigPlotSignal.connect(self._handleROIMarkerEvent)
+ self.plot.sigActiveCurveChanged.connect(
+ self._activeCurveChanged)
+ self._isConnected = True
+
+ self.calculateROIs()
+ else:
+ if self._isConnected:
+ self.plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent)
+ self.plot.sigActiveCurveChanged.disconnect(
+ self._activeCurveChanged)
+ self._isConnected = False
+
+ def _handleROIMarkerEvent(self, ddict):
+ """Handle plot signals related to marker events."""
+ if ddict['event'] == 'markerMoved':
+
+ label = ddict['label']
+ if label not in ['ROI min', 'ROI max', 'ROI middle']:
+ return
+
+ roiList, roiDict = self.roiWidget.getROIListAndDict()
+ if self.currentROI is None:
+ return
+ if self.currentROI not in roiDict:
+ return
+ x = ddict['x']
+
+ if label == 'ROI min':
+ roiDict[self.currentROI]['from'] = x
+ if self._middleROIMarkerFlag:
+ pos = 0.5 * (roiDict[self.currentROI]['to'] +
+ roiDict[self.currentROI]['from'])
+ self.plot.addXMarker(pos,
+ legend='ROI middle',
+ text='',
+ color='yellow',
+ draggable=True)
+ elif label == 'ROI max':
+ roiDict[self.currentROI]['to'] = x
+ if self._middleROIMarkerFlag:
+ pos = 0.5 * (roiDict[self.currentROI]['to'] +
+ roiDict[self.currentROI]['from'])
+ self.plot.addXMarker(pos,
+ legend='ROI middle',
+ text='',
+ color='yellow',
+ draggable=True)
+ elif label == 'ROI middle':
+ delta = x - 0.5 * (roiDict[self.currentROI]['from'] +
+ roiDict[self.currentROI]['to'])
+ roiDict[self.currentROI]['from'] += delta
+ roiDict[self.currentROI]['to'] += delta
+ self.plot.addXMarker(roiDict[self.currentROI]['from'],
+ legend='ROI min',
+ text='ROI min',
+ color='blue',
+ draggable=True)
+ self.plot.addXMarker(roiDict[self.currentROI]['to'],
+ legend='ROI max',
+ text='ROI max',
+ color='blue',
+ draggable=True)
+ else:
+ return
+ self.calculateROIs(roiList, roiDict)
+ self._emitCurrentROISignal()
+
+ def _roiSignal(self, ddict):
+ """Handle ROI widget signal"""
+ _logger.debug("PlotWindow._roiSignal %s", str(ddict))
+ if ddict['event'] == "AddROI":
+ xmin, xmax = self.plot.getGraphXLimits()
+ fromdata = xmin + 0.25 * (xmax - xmin)
+ todata = xmin + 0.75 * (xmax - xmin)
+ self.plot.remove('ROI min', kind='marker')
+ self.plot.remove('ROI max', kind='marker')
+ if self._middleROIMarkerFlag:
+ self.remove('ROI middle', kind='marker')
+ roiList, roiDict = self.roiWidget.getROIListAndDict()
+ nrois = len(roiList)
+ if nrois == 0:
+ newroi = "ICR"
+ fromdata, dummy0, todata, dummy1 = self._getAllLimits()
+ draggable = False
+ color = 'black'
+ else:
+ for i in range(nrois):
+ i += 1
+ newroi = "newroi %d" % i
+ if newroi not in roiList:
+ break
+ color = 'blue'
+ draggable = True
+ self.plot.addXMarker(fromdata,
+ legend='ROI min',
+ text='ROI min',
+ color=color,
+ draggable=draggable)
+ self.plot.addXMarker(todata,
+ legend='ROI max',
+ text='ROI max',
+ color=color,
+ draggable=draggable)
+ if draggable and self._middleROIMarkerFlag:
+ pos = 0.5 * (fromdata + todata)
+ self.plot.addXMarker(pos,
+ legend='ROI middle',
+ text="",
+ color='yellow',
+ draggable=draggable)
+ roiList.append(newroi)
+ roiDict[newroi] = {}
+ if newroi == "ICR":
+ roiDict[newroi]['type'] = "Default"
+ else:
+ roiDict[newroi]['type'] = self.plot.getGraphXLabel()
+ roiDict[newroi]['from'] = fromdata
+ roiDict[newroi]['to'] = todata
+ self.roiWidget.fillFromROIDict(roilist=roiList,
+ roidict=roiDict,
+ currentroi=newroi)
+ self.currentROI = newroi
+ self.calculateROIs()
+ elif ddict['event'] in ['DelROI', "ResetROI"]:
+ self.plot.remove('ROI min', kind='marker')
+ self.plot.remove('ROI max', kind='marker')
+ if self._middleROIMarkerFlag:
+ self.plot.remove('ROI middle', kind='marker')
+ roiList, roiDict = self.roiWidget.getROIListAndDict()
+ roiDictKeys = list(roiDict.keys())
+ if len(roiDictKeys):
+ currentroi = roiDictKeys[0]
+ else:
+ # create again the ICR
+ ddict = {"event": "AddROI"}
+ return self._roiSignal(ddict)
+
+ self.roiWidget.fillFromROIDict(roilist=roiList,
+ roidict=roiDict,
+ currentroi=currentroi)
+ self.currentROI = currentroi
+
+ elif ddict['event'] == 'LoadROI':
+ self.calculateROIs()
+
+ elif ddict['event'] == 'selectionChanged':
+ _logger.debug("Selection changed")
+ self.roilist, self.roidict = self.roiWidget.getROIListAndDict()
+ fromdata = ddict['roi']['from']
+ todata = ddict['roi']['to']
+ self.plot.remove('ROI min', kind='marker')
+ self.plot.remove('ROI max', kind='marker')
+ if ddict['key'] == 'ICR':
+ draggable = False
+ color = 'black'
+ else:
+ draggable = True
+ color = 'blue'
+ self.plot.addXMarker(fromdata,
+ legend='ROI min',
+ text='ROI min',
+ color=color,
+ draggable=draggable)
+ self.plot.addXMarker(todata,
+ legend='ROI max',
+ text='ROI max',
+ color=color,
+ draggable=draggable)
+ if draggable and self._middleROIMarkerFlag:
+ pos = 0.5 * (fromdata + todata)
+ self.plot.addXMarker(pos,
+ legend='ROI middle',
+ text="",
+ color='yellow',
+ draggable=True)
+ self.currentROI = ddict['key']
+ if ddict['colheader'] in ['From', 'To']:
+ dict0 = {}
+ dict0['event'] = "SetActiveCurveEvent"
+ dict0['legend'] = self.plot.getActiveCurve(just_legend=1)
+ self.plot.setActiveCurve(dict0['legend'])
+ elif ddict['colheader'] == 'Raw Counts':
+ pass
+ elif ddict['colheader'] == 'Net Counts':
+ pass
+ else:
+ self._emitCurrentROISignal()
+
+ else:
+ _logger.debug("Unknown or ignored event %s", ddict['event'])
+
+ def _activeCurveChanged(self, *args):
+ """Recompute ROIs when active curve changed."""
+ self.calculateROIs()
+
+ def calculateRois(self, roiList=None, roiDict=None):
+ """Compute ROI information"""
+ if roiList is None or roiDict is None:
+ roiList, roiDict = self.roiWidget.getROIListAndDict()
+
+ activeCurve = self.plot.getActiveCurve(just_legend=False)
+ if activeCurve is None:
+ xproc = None
+ yproc = None
+ self.roiWidget.setHeader()
+ else:
+ x = activeCurve.getXData(copy=False)
+ y = activeCurve.getYData(copy=False)
+ legend = activeCurve.getLegend()
+ idx = numpy.argsort(x, kind='mergesort')
+ xproc = numpy.take(x, idx)
+ yproc = numpy.take(y, idx)
+ self.roiWidget.setHeader('ROIs of %s' % legend)
+
+ for key in roiList:
+ if key == 'ICR':
+ if xproc is not None:
+ roiDict[key]['from'] = xproc.min()
+ roiDict[key]['to'] = xproc.max()
+ else:
+ roiDict[key]['from'] = 0
+ roiDict[key]['to'] = -1
+ fromData = roiDict[key]['from']
+ toData = roiDict[key]['to']
+ if xproc is not None:
+ idx = numpy.nonzero((fromData <= xproc) &
+ (xproc <= toData))[0]
+ if len(idx):
+ xw = xproc[idx]
+ yw = yproc[idx]
+ rawCounts = yw.sum(dtype=numpy.float)
+ deltaX = xw[-1] - xw[0]
+ deltaY = yw[-1] - yw[0]
+ if deltaX > 0.0:
+ slope = (deltaY / deltaX)
+ background = yw[0] + slope * (xw - xw[0])
+ netCounts = (rawCounts -
+ background.sum(dtype=numpy.float))
+ else:
+ netCounts = 0.0
+ else:
+ rawCounts = 0.0
+ netCounts = 0.0
+ roiDict[key]['rawcounts'] = rawCounts
+ roiDict[key]['netcounts'] = netCounts
+ else:
+ roiDict[key].pop('rawcounts', None)
+ roiDict[key].pop('netcounts', None)
+
+ self.roiWidget.fillFromROIDict(
+ roilist=roiList,
+ roidict=roiDict,
+ currentroi=self.currentROI if self.currentROI in roiList else None)
+
+ def _emitCurrentROISignal(self):
+ ddict = {}
+ ddict['event'] = "currentROISignal"
+ _roiList, roiDict = self.roiWidget.getROIListAndDict()
+ if self.currentROI in roiDict:
+ ddict['ROI'] = roiDict[self.currentROI]
+ else:
+ self.currentROI = None
+ ddict['current'] = self.currentROI
+ self.sigROISignal.emit(ddict)
+
+ def _getAllLimits(self):
+ """Retrieve the limits based on the curves."""
+ curves = self.plot.getAllCurves()
+ if not curves:
+ return 1.0, 1.0, 100., 100.
+
+ xmin, ymin = None, None
+ xmax, ymax = None, None
+
+ for curve in curves:
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+ if xmin is None:
+ xmin = x.min()
+ else:
+ xmin = min(xmin, x.min())
+ if xmax is None:
+ xmax = x.max()
+ else:
+ xmax = max(xmax, x.max())
+ if ymin is None:
+ ymin = y.min()
+ else:
+ ymin = min(ymin, y.min())
+ if ymax is None:
+ ymax = y.max()
+ else:
+ ymax = max(ymax, y.max())
+
+ return xmin, ymin, xmax, ymax
+
+ def showEvent(self, event):
+ """Make sure this widget is raised when it is shown
+ (when it is first created as a tab in PlotWindow or when it is shown
+ again after hiding).
+ """
+ self.raise_()
diff --git a/silx/gui/plot/ImageView.py b/silx/gui/plot/ImageView.py
new file mode 100644
index 0000000..780215e
--- /dev/null
+++ b/silx/gui/plot/ImageView.py
@@ -0,0 +1,860 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""QWidget displaying a 2D image with histograms on its sides.
+
+The :class:`ImageView` implements this widget, and
+:class:`ImageViewMainWindow` provides a main window with additional toolbar
+and status bar.
+
+Basic usage of :class:`ImageView` is through the following methods:
+
+- :meth:`ImageView.getColormap`, :meth:`ImageView.setColormap` to update the
+ default colormap to use and update the currently displayed image.
+- :meth:`ImageView.setImage` to update the displayed image.
+
+The :class:`ImageView` uses :class:`PlotWindow` and also
+exposes :class:`silx.gui.plot.Plot` API for further control
+(plot title, axes labels, adding other images, ...).
+
+For an example of use, see the implementation of :class:`ImageViewMainWindow`,
+and `example/imageview.py`.
+"""
+
+from __future__ import division
+
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "13/10/2016"
+
+
+import logging
+import numpy
+
+from .. import qt
+
+from . import items, PlotWindow, PlotWidget, PlotActions
+from .Colors import cursorColorForColormap
+from .PlotTools import LimitsToolBar
+from .Profile import ProfileToolBar
+
+
+_logger = logging.getLogger(__name__)
+
+
+# RadarView ###################################################################
+
+class RadarView(qt.QGraphicsView):
+ """Widget presenting a synthetic view of a 2D area and
+ the current visible area.
+
+ Coordinates are as in QGraphicsView:
+ x goes from left to right and y goes from top to bottom.
+ This widget preserves the aspect ratio of the areas.
+
+ The 2D area and the visible area can be set with :meth:`setDataRect`
+ and :meth:`setVisibleRect`.
+ When the visible area has been dragged by the user, its new position
+ is signaled by the *visibleRectDragged* signal.
+
+ It is possible to invert the direction of the axes by using the
+ :meth:`scale` method of QGraphicsView.
+ """
+
+ visibleRectDragged = qt.Signal(float, float, float, float)
+ """Signals that the visible rectangle has been dragged.
+
+ It provides: left, top, width, height in data coordinates.
+ """
+
+ _DATA_PEN = qt.QPen(qt.QColor('white'))
+ _DATA_BRUSH = qt.QBrush(qt.QColor('light gray'))
+ _VISIBLE_PEN = qt.QPen(qt.QColor('red'))
+ _VISIBLE_PEN.setWidth(2)
+ _VISIBLE_PEN.setCosmetic(True)
+ _VISIBLE_BRUSH = qt.QBrush(qt.QColor(0, 0, 0, 0))
+ _TOOLTIP = 'Radar View:\nRed contour: Visible area\nGray area: The image'
+
+ _PIXMAP_SIZE = 256
+
+ class _DraggableRectItem(qt.QGraphicsRectItem):
+ """RectItem which signals its change through visibleRectDragged."""
+ def __init__(self, *args, **kwargs):
+ super(RadarView._DraggableRectItem, self).__init__(
+ *args, **kwargs)
+
+ self._previousCursor = None
+ self.setFlag(qt.QGraphicsItem.ItemIsMovable)
+ self.setFlag(qt.QGraphicsItem.ItemSendsGeometryChanges)
+ self.setAcceptHoverEvents(True)
+ self._ignoreChange = False
+ self._constraint = 0, 0, 0, 0
+
+ def setConstraintRect(self, left, top, width, height):
+ """Set the constraint rectangle for dragging.
+
+ The coordinates are in the _DraggableRectItem coordinate system.
+
+ This constraint only applies to modification through interaction
+ (i.e., this constraint is not applied to change through API).
+
+ If the _DraggableRectItem is smaller than the constraint rectangle,
+ the _DraggableRectItem remains within the constraint rectangle.
+ If the _DraggableRectItem is wider than the constraint rectangle,
+ the constraint rectangle remains within the _DraggableRectItem.
+ """
+ self._constraint = left, left + width, top, top + height
+
+ def setPos(self, *args, **kwargs):
+ """Overridden to ignore changes from API in itemChange."""
+ self._ignoreChange = True
+ super(RadarView._DraggableRectItem, self).setPos(*args, **kwargs)
+ self._ignoreChange = False
+
+ def moveBy(self, *args, **kwargs):
+ """Overridden to ignore changes from API in itemChange."""
+ self._ignoreChange = True
+ super(RadarView._DraggableRectItem, self).moveBy(*args, **kwargs)
+ self._ignoreChange = False
+
+ def itemChange(self, change, value):
+ """Callback called before applying changes to the item."""
+ if (change == qt.QGraphicsItem.ItemPositionChange and
+ not self._ignoreChange):
+ # Makes sure that the visible area is in the data
+ # or that data is in the visible area if area is too wide
+ x, y = value.x(), value.y()
+ xMin, xMax, yMin, yMax = self._constraint
+
+ if self.rect().width() <= (xMax - xMin):
+ if x < xMin:
+ value.setX(xMin)
+ elif x > xMax - self.rect().width():
+ value.setX(xMax - self.rect().width())
+ else:
+ if x > xMin:
+ value.setX(xMin)
+ elif x < xMax - self.rect().width():
+ value.setX(xMax - self.rect().width())
+
+ if self.rect().height() <= (yMax - yMin):
+ if y < yMin:
+ value.setY(yMin)
+ elif y > yMax - self.rect().height():
+ value.setY(yMax - self.rect().height())
+ else:
+ if y > yMin:
+ value.setY(yMin)
+ elif y < yMax - self.rect().height():
+ value.setY(yMax - self.rect().height())
+
+ if self.pos() != value:
+ # Notify change through signal
+ views = self.scene().views()
+ assert len(views) == 1
+ views[0].visibleRectDragged.emit(
+ value.x() + self.rect().left(),
+ value.y() + self.rect().top(),
+ self.rect().width(),
+ self.rect().height())
+
+ return value
+
+ return super(RadarView._DraggableRectItem, self).itemChange(
+ change, value)
+
+ def hoverEnterEvent(self, event):
+ """Called when the mouse enters the rectangle area"""
+ self._previousCursor = self.cursor()
+ self.setCursor(qt.Qt.OpenHandCursor)
+
+ def hoverLeaveEvent(self, event):
+ """Called when the mouse leaves the rectangle area"""
+ if self._previousCursor is not None:
+ self.setCursor(self._previousCursor)
+ self._previousCursor = None
+
+ def __init__(self, parent=None):
+ self._scene = qt.QGraphicsScene()
+ self._dataRect = self._scene.addRect(0, 0, 1, 1,
+ self._DATA_PEN,
+ self._DATA_BRUSH)
+ self._visibleRect = self._DraggableRectItem(0, 0, 1, 1)
+ self._visibleRect.setPen(self._VISIBLE_PEN)
+ self._visibleRect.setBrush(self._VISIBLE_BRUSH)
+ self._scene.addItem(self._visibleRect)
+
+ super(RadarView, self).__init__(self._scene, parent)
+ self.setHorizontalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff)
+ self.setVerticalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff)
+ self.setFocusPolicy(qt.Qt.NoFocus)
+ self.setStyleSheet('border: 0px')
+ self.setToolTip(self._TOOLTIP)
+
+ def sizeHint(self):
+ # """Overridden to avoid sizeHint to depend on content size."""
+ return self.minimumSizeHint()
+
+ def wheelEvent(self, event):
+ # """Overridden to disable vertical scrolling with wheel."""
+ event.ignore()
+
+ def resizeEvent(self, event):
+ # """Overridden to fit current content to new size."""
+ self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+ super(RadarView, self).resizeEvent(event)
+
+ def setDataRect(self, left, top, width, height):
+ """Set the bounds of the data rectangular area.
+
+ This sets the coordinate system.
+ """
+ self._dataRect.setRect(left, top, width, height)
+ self._visibleRect.setConstraintRect(left, top, width, height)
+ self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+
+ def setVisibleRect(self, left, top, width, height):
+ """Set the visible rectangular area.
+
+ The coordinates are relative to the data rect.
+ """
+ self._visibleRect.setRect(0, 0, width, height)
+ self._visibleRect.setPos(left, top)
+ self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+
+
+# ImageView ###################################################################
+
+class ImageView(PlotWindow):
+ """Display a single image with horizontal and vertical histograms.
+
+ Use :meth:`setImage` to control the displayed image.
+ This class also provides the :class:`silx.gui.plot.Plot` API.
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.Plot` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ HISTOGRAMS_COLOR = 'blue'
+ """Color to use for the side histograms."""
+
+ HISTOGRAMS_HEIGHT = 200
+ """Height in pixels of the side histograms."""
+
+ IMAGE_MIN_SIZE = 200
+ """Minimum size in pixels of the image area."""
+
+ # Qt signals
+ valueChanged = qt.Signal(float, float, float)
+ """Signals that the data value under the cursor has changed.
+
+ It provides: row, column, data value.
+
+ When the cursor is over an histogram, either row or column is Nan
+ and the provided data value is the histogram value
+ (i.e., the sum along the corresponding row/column).
+ Row and columns are either Nan or integer values.
+ """
+
+ def __init__(self, parent=None, backend=None):
+ self._imageLegend = '__ImageView__image' + str(id(self))
+ self._cache = None # Store currently visible data information
+ self._updatingLimits = False
+
+ super(ImageView, self).__init__(parent=parent, backend=backend,
+ resetzoom=True, autoScale=False,
+ logScale=False, grid=False,
+ curveStyle=False, colormap=True,
+ aspectRatio=True, yInverted=True,
+ copy=True, save=True, print_=True,
+ control=False, position=False,
+ roi=False, mask=True)
+ if parent is None:
+ self.setWindowTitle('ImageView')
+
+ self._initWidgets(backend)
+
+ self.profile = ProfileToolBar(plot=self)
+ """"Profile tools attached to this plot.
+
+ See :class:`silx.gui.plot.PlotTools.ProfileToolBar`
+ """
+
+ self.addToolBar(self.profile)
+
+ # Sync PlotBackend and ImageView
+ self._updateYAxisInverted()
+
+ def _initWidgets(self, backend):
+ """Set-up layout and plots."""
+ # Monkey-patch for histogram size
+ # alternative: create a layout that does not use widget size hints
+ def sizeHint():
+ return qt.QSize(self.HISTOGRAMS_HEIGHT, self.HISTOGRAMS_HEIGHT)
+
+ self._histoHPlot = PlotWidget(backend=backend)
+ self._histoHPlot.setInteractiveMode('zoom')
+ self._histoHPlot.setCallback(self._histoHPlotCB)
+ self._histoHPlot.getWidgetHandle().sizeHint = sizeHint
+ self._histoHPlot.getWidgetHandle().minimumSizeHint = sizeHint
+
+ self.setPanWithArrowKeys(True)
+
+ self.setInteractiveMode('zoom') # Color set in setColormap
+ self.sigPlotSignal.connect(self._imagePlotCB)
+ self.sigSetYAxisInverted.connect(self._updateYAxisInverted)
+ self.sigActiveImageChanged.connect(self._activeImageChangedSlot)
+
+ self._histoVPlot = PlotWidget(backend=backend)
+ self._histoVPlot.setInteractiveMode('zoom')
+ self._histoVPlot.setCallback(self._histoVPlotCB)
+ self._histoVPlot.getWidgetHandle().sizeHint = sizeHint
+ self._histoVPlot.getWidgetHandle().minimumSizeHint = sizeHint
+
+ self._radarView = RadarView()
+ self._radarView.visibleRectDragged.connect(self._radarViewCB)
+
+ self._layout = qt.QGridLayout()
+ self._layout.addWidget(self.getWidgetHandle(), 0, 0)
+ self._layout.addWidget(self._histoVPlot.getWidgetHandle(), 0, 1)
+ self._layout.addWidget(self._histoHPlot.getWidgetHandle(), 1, 0)
+ self._layout.addWidget(self._radarView, 1, 1)
+
+ self._layout.setColumnMinimumWidth(0, self.IMAGE_MIN_SIZE)
+ self._layout.setColumnStretch(0, 1)
+ self._layout.setColumnMinimumWidth(1, self.HISTOGRAMS_HEIGHT)
+ self._layout.setColumnStretch(1, 0)
+
+ self._layout.setRowMinimumHeight(0, self.IMAGE_MIN_SIZE)
+ self._layout.setRowStretch(0, 1)
+ self._layout.setRowMinimumHeight(1, self.HISTOGRAMS_HEIGHT)
+ self._layout.setRowStretch(1, 0)
+
+ self._layout.setSpacing(0)
+ self._layout.setContentsMargins(0, 0, 0, 0)
+
+ centralWidget = qt.QWidget()
+ centralWidget.setLayout(self._layout)
+ self.setCentralWidget(centralWidget)
+
+ def _dirtyCache(self):
+ self._cache = None
+
+ def _updateHistograms(self):
+ """Update histograms content using current active image."""
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ wasUpdatingLimits = self._updatingLimits
+ self._updatingLimits = True
+
+ data = activeImage.getData(copy=False)
+ origin = activeImage.getOrigin()
+ scale = activeImage.getScale()
+ height, width = data.shape
+
+ xMin, xMax = self.getGraphXLimits()
+ yMin, yMax = self.getGraphYLimits()
+
+ # Convert plot area limits to image coordinates
+ # and work in image coordinates (i.e., in pixels)
+ xMin = int((xMin - origin[0]) / scale[0])
+ xMax = int((xMax - origin[0]) / scale[0])
+ yMin = int((yMin - origin[1]) / scale[1])
+ yMax = int((yMax - origin[1]) / scale[1])
+
+ if (xMin < width and xMax >= 0 and
+ yMin < height and yMax >= 0):
+ # The image is at least partly in the plot area
+ # Get the visible bounds in image coords (i.e., in pixels)
+ subsetXMin = 0 if xMin < 0 else xMin
+ subsetXMax = (width if xMax >= width else xMax) + 1
+ subsetYMin = 0 if yMin < 0 else yMin
+ subsetYMax = (height if yMax >= height else yMax) + 1
+
+ if (self._cache is None or
+ subsetXMin != self._cache['dataXMin'] or
+ subsetXMax != self._cache['dataXMax'] or
+ subsetYMin != self._cache['dataYMin'] or
+ subsetYMax != self._cache['dataYMax']):
+ # The visible area of data has changed, update histograms
+
+ # Rebuild histograms for visible area
+ visibleData = data[subsetYMin:subsetYMax,
+ subsetXMin:subsetXMax]
+ histoHVisibleData = numpy.sum(visibleData, axis=0)
+ histoVVisibleData = numpy.sum(visibleData, axis=1)
+
+ self._cache = {
+ 'dataXMin': subsetXMin,
+ 'dataXMax': subsetXMax,
+ 'dataYMin': subsetYMin,
+ 'dataYMax': subsetYMax,
+
+ 'histoH': histoHVisibleData,
+ 'histoHMin': numpy.min(histoHVisibleData),
+ 'histoHMax': numpy.max(histoHVisibleData),
+
+ 'histoV': histoVVisibleData,
+ 'histoVMin': numpy.min(histoVVisibleData),
+ 'histoVMax': numpy.max(histoVVisibleData)
+ }
+
+ # Convert to histogram curve and update plots
+ # Taking into account origin and scale
+ coords = numpy.arange(2 * histoHVisibleData.size)
+ xCoords = (coords + 1) // 2 + subsetXMin
+ xCoords = origin[0] + scale[0] * xCoords
+ xData = numpy.take(histoHVisibleData, coords // 2)
+ self._histoHPlot.addCurve(xCoords, xData,
+ xlabel='', ylabel='',
+ replace=False,
+ color=self.HISTOGRAMS_COLOR,
+ linestyle='-',
+ selectable=False)
+ vMin = self._cache['histoHMin']
+ vMax = self._cache['histoHMax']
+ vOffset = 0.1 * (vMax - vMin)
+ if vOffset == 0.:
+ vOffset = 1.
+ self._histoHPlot.setGraphYLimits(vMin - vOffset,
+ vMax + vOffset)
+
+ coords = numpy.arange(2 * histoVVisibleData.size)
+ yCoords = (coords + 1) // 2 + subsetYMin
+ yCoords = origin[1] + scale[1] * yCoords
+ yData = numpy.take(histoVVisibleData, coords // 2)
+ self._histoVPlot.addCurve(yData, yCoords,
+ xlabel='', ylabel='',
+ replace=False,
+ color=self.HISTOGRAMS_COLOR,
+ linestyle='-',
+ selectable=False)
+ vMin = self._cache['histoVMin']
+ vMax = self._cache['histoVMax']
+ vOffset = 0.1 * (vMax - vMin)
+ if vOffset == 0.:
+ vOffset = 1.
+ self._histoVPlot.setGraphXLimits(vMin - vOffset,
+ vMax + vOffset)
+ else:
+ self._dirtyCache()
+ self._histoHPlot.remove(kind='curve')
+ self._histoVPlot.remove(kind='curve')
+
+ self._updatingLimits = wasUpdatingLimits
+
+ def _updateRadarView(self):
+ """Update radar view visible area.
+
+ Takes care of y coordinate conversion.
+ """
+ xMin, xMax = self.getGraphXLimits()
+ yMin, yMax = self.getGraphYLimits()
+ self._radarView.setVisibleRect(xMin, yMin, xMax - xMin, yMax - yMin)
+
+ # Plots event listeners
+
+ def _imagePlotCB(self, eventDict):
+ """Callback for imageView plot events."""
+ if eventDict['event'] == 'mouseMoved':
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ data = activeImage.getData(copy=False)
+ height, width = data.shape
+
+ # Get corresponding coordinate in image
+ origin = activeImage.getOrigin()
+ scale = activeImage.getScale()
+ if (eventDict['x'] >= origin[0] and
+ eventDict['y'] >= origin[1]):
+ x = int((eventDict['x'] - origin[0]) / scale[0])
+ y = int((eventDict['y'] - origin[1]) / scale[1])
+
+ if x >= 0 and x < width and y >= 0 and y < height:
+ self.valueChanged.emit(float(x), float(y),
+ data[y][x])
+
+ elif eventDict['event'] == 'limitsChanged':
+ # Do not handle histograms limitsChanged while
+ # updating their limits from here.
+ self._updatingLimits = True
+
+ # Refresh histograms
+ self._updateHistograms()
+
+ # could use eventDict['xdata'], eventDict['ydata'] instead
+ xMin, xMax = self.getGraphXLimits()
+ yMin, yMax = self.getGraphYLimits()
+
+ # Set horizontal histo limits
+ self._histoHPlot.setGraphXLimits(xMin, xMax)
+
+ # Set vertical histo limits
+ self._histoVPlot.setGraphYLimits(yMin, yMax)
+
+ self._updateRadarView()
+
+ self._updatingLimits = False
+
+ def _histoHPlotCB(self, eventDict):
+ """Callback for horizontal histogram plot events."""
+ if eventDict['event'] == 'mouseMoved':
+ if self._cache is not None:
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ xOrigin = activeImage.getOrigin()[0]
+ xScale = activeImage.getScale()[0]
+
+ minValue = xOrigin + xScale * self._cache['dataXMin']
+
+ if eventDict['x'] >= minValue:
+ data = self._cache['histoH']
+ column = int((eventDict['x'] - minValue) / xScale)
+ if column >= 0 and column < data.shape[0]:
+ self.valueChanged.emit(
+ float('nan'),
+ float(column + self._cache['dataXMin']),
+ data[column])
+
+ elif eventDict['event'] == 'limitsChanged':
+ if (not self._updatingLimits and
+ eventDict['xdata'] != self.getGraphXLimits()):
+ xMin, xMax = eventDict['xdata']
+ self.setGraphXLimits(xMin, xMax)
+
+ def _histoVPlotCB(self, eventDict):
+ """Callback for vertical histogram plot events."""
+ if eventDict['event'] == 'mouseMoved':
+ if self._cache is not None:
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ yOrigin = activeImage.getOrigin()[1]
+ yScale = activeImage.getScale()[1]
+
+ minValue = yOrigin + yScale * self._cache['dataYMin']
+
+ if eventDict['y'] >= minValue:
+ data = self._cache['histoV']
+ row = int((eventDict['y'] - minValue) / yScale)
+ if row >= 0 and row < data.shape[0]:
+ self.valueChanged.emit(
+ float(row + self._cache['dataYMin']),
+ float('nan'),
+ data[row])
+
+ elif eventDict['event'] == 'limitsChanged':
+ if (not self._updatingLimits and
+ eventDict['ydata'] != self.getGraphYLimits()):
+ yMin, yMax = eventDict['ydata']
+ self.setGraphYLimits(yMin, yMax)
+
+ def _radarViewCB(self, left, top, width, height):
+ """Slot for radar view visible rectangle changes."""
+ if not self._updatingLimits:
+ # Takes care of Y axis conversion
+ self.setLimits(left, left + width, top, top + height)
+
+ def _updateYAxisInverted(self, inverted=None):
+ """Sync image, vertical histogram and radar view axis orientation."""
+ if inverted is None:
+ # Do not perform this when called from plot signal
+ inverted = self.isYAxisInverted()
+
+ self._histoVPlot.setYAxisInverted(inverted)
+
+ # Use scale to invert radarView
+ # RadarView default Y direction is from top to bottom
+ # As opposed to Plot. So invert RadarView when Plot is NOT inverted.
+ self._radarView.resetTransform()
+ if not inverted:
+ self._radarView.scale(1., -1.)
+ self._updateRadarView()
+
+ self._radarView.update()
+
+ def _activeImageChangedSlot(self, previous, legend):
+ """Handle Plot active image change.
+
+ Resets side histograms cache
+ """
+ self._dirtyCache()
+ self._updateHistograms()
+
+ def getHistogram(self, axis):
+ """Return the histogram and corresponding row or column extent.
+
+ The returned value when an histogram is available is a dict with keys:
+
+ - 'data': numpy array of the histogram values.
+ - 'extent': (start, end) row or column index.
+ end index is not included in the histogram.
+
+ :param str axis: 'x' for horizontal, 'y' for vertical
+ :return: The histogram and its extent as a dict or None.
+ :rtype: dict
+ """
+ assert axis in ('x', 'y')
+ if self._cache is None:
+ return None
+ else:
+ if axis == 'x':
+ return dict(
+ data=numpy.array(self._cache['histoH'], copy=True),
+ extent=(self._cache['dataXMin'], self._cache['dataXMax']))
+ else:
+ return dict(
+ data=numpy.array(self._cache['histoV'], copy=True),
+ extent=(self._cache['dataYMin'], self._cache['dataYMax']))
+
+ def radarView(self):
+ """Get the lower right radarView widget."""
+ return self._radarView
+
+ def setRadarView(self, radarView):
+ """Change the lower right radarView widget.
+
+ :param RadarView radarView: Widget subclassing RadarView to replace
+ the lower right corner widget.
+ """
+ self._radarView.visibleRectDragged.disconnect(self._radarViewCB)
+ self._radarView = radarView
+ self._radarView.visibleRectDragged.connect(self._radarViewCB)
+ self._layout.addWidget(self._radarView, 1, 1)
+
+ self._updateYAxisInverted()
+
+ # High-level API
+
+ def getColormap(self):
+ """Get the default colormap description.
+
+ :return: A description of the current colormap.
+ See :meth:`setColormap` for details.
+ :rtype: dict
+ """
+ return self.getDefaultColormap()
+
+ def setColormap(self, colormap=None, normalization=None,
+ autoscale=None, vmin=None, vmax=None, colors=None):
+ """Set the default colormap and update active image.
+
+ Parameters that are not provided are taken from the current colormap.
+
+ The colormap parameter can also be a dict with the following keys:
+
+ - *name*: string. The colormap to use:
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ - *normalization*: string. The mapping to use for the colormap:
+ either 'linear' or 'log'.
+ - *autoscale*: bool. Whether to use autoscale (True)
+ or range provided by keys 'vmin' and 'vmax' (False).
+ - *vmin*: float. The minimum value of the range to use if 'autoscale'
+ is False.
+ - *vmax*: float. The maximum value of the range to use if 'autoscale'
+ is False.
+ - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8.
+ List of RGB or RGBA colors to use (only if name is None)
+
+ :param colormap: Name of the colormap in
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ Or the description of the colormap as a dict.
+ :type colormap: dict or str.
+ :param str normalization: Colormap mapping: 'linear' or 'log'.
+ :param bool autoscale: Whether to use autoscale (True)
+ or [vmin, vmax] range (False).
+ :param float vmin: The minimum value of the range to use if
+ 'autoscale' is False.
+ :param float vmax: The maximum value of the range to use if
+ 'autoscale' is False.
+ :param numpy.ndarray colors: Only used if name is None.
+ Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
+ """
+ cmapDict = self.getDefaultColormap()
+
+ if isinstance(colormap, dict):
+ # Support colormap parameter as a dict
+ assert normalization is None
+ assert autoscale is None
+ assert vmin is None
+ assert vmax is None
+ assert colors is None
+ for key, value in colormap.items():
+ cmapDict[key] = value
+
+ else:
+ if colormap is not None:
+ cmapDict['name'] = colormap
+ if normalization is not None:
+ cmapDict['normalization'] = normalization
+ if autoscale is not None:
+ cmapDict['autoscale'] = autoscale
+ if vmin is not None:
+ cmapDict['vmin'] = vmin
+ if vmax is not None:
+ cmapDict['vmax'] = vmax
+ if colors is not None:
+ cmapDict['colors'] = colors
+
+ cursorColor = cursorColorForColormap(cmapDict['name'])
+ self.setInteractiveMode('zoom', color=cursorColor)
+
+ self.setDefaultColormap(cmapDict)
+
+ # Update active image colormap
+ activeImage = self.getActiveImage()
+ if isinstance(activeImage, items.ColormapMixIn):
+ activeImage.setColormap(self.getColormap())
+
+ def setImage(self, image, origin=(0, 0), scale=(1., 1.),
+ copy=True, reset=True):
+ """Set the image to display.
+
+ :param image: A 2D array representing the image or None to empty plot.
+ :type image: numpy.ndarray-like with 2 dimensions or None.
+ :param origin: The (x, y) position of the origin of the image.
+ Default: (0, 0).
+ The origin is the lower left corner of the image when
+ the Y axis is not inverted.
+ :type origin: Tuple of 2 floats: (origin x, origin y).
+ :param scale: The scale factor to apply to the image on X and Y axes.
+ Default: (1, 1).
+ It is the size of a pixel in the coordinates of the axes.
+ Scales must be positive numbers.
+ :type scale: Tuple of 2 floats: (scale x, scale y).
+ :param bool copy: Whether to copy image data (default) or not.
+ :param bool reset: Whether to reset zoom and ROI (default) or not.
+ """
+ self._dirtyCache()
+
+ assert len(origin) == 2
+ assert len(scale) == 2
+ assert scale[0] > 0
+ assert scale[1] > 0
+
+ if image is None:
+ self.remove(self._imageLegend, kind='image')
+ return
+
+ data = numpy.array(image, order='C', copy=copy)
+ assert data.size != 0
+ assert len(data.shape) == 2
+ height, width = data.shape
+
+ self.addImage(data,
+ legend=self._imageLegend,
+ origin=origin, scale=scale,
+ colormap=self.getColormap(),
+ replace=False)
+ self.setActiveImage(self._imageLegend)
+ self._updateHistograms()
+
+ self._radarView.setDataRect(origin[0],
+ origin[1],
+ width * scale[0],
+ height * scale[1])
+
+ if reset:
+ self.resetZoom()
+
+
+# ImageViewMainWindow #########################################################
+
+class ImageViewMainWindow(ImageView):
+ """:class:`ImageView` with additional toolbars
+
+ Adds extra toolbar and a status bar to :class:`ImageView`.
+ """
+ def __init__(self, parent=None, backend=None):
+ self._dataInfo = None
+ super(ImageViewMainWindow, self).__init__(parent, backend)
+ self.setWindowFlags(qt.Qt.Window)
+
+ self.setGraphXLabel('X')
+ self.setGraphYLabel('Y')
+ self.setGraphTitle('Image')
+
+ # Add toolbars and status bar
+ self.addToolBar(qt.Qt.BottomToolBarArea, LimitsToolBar(plot=self))
+
+ self.statusBar()
+
+ menu = self.menuBar().addMenu('File')
+ menu.addAction(self.saveAction)
+ menu.addAction(self.printAction)
+ menu.addSeparator()
+ action = menu.addAction('Quit')
+ action.triggered[bool].connect(qt.QApplication.instance().quit)
+
+ menu = self.menuBar().addMenu('Edit')
+ menu.addAction(self.copyAction)
+ menu.addSeparator()
+ menu.addAction(self.resetZoomAction)
+ menu.addAction(self.colormapAction)
+ menu.addAction(PlotActions.KeepAspectRatioAction(self, self))
+ menu.addAction(PlotActions.YAxisInvertedAction(self, self))
+
+ menu = self.menuBar().addMenu('Profile')
+ menu.addAction(self.profile.browseAction)
+ menu.addAction(self.profile.hLineAction)
+ menu.addAction(self.profile.vLineAction)
+ menu.addAction(self.profile.lineAction)
+ menu.addAction(self.profile.clearAction)
+
+ # Connect to ImageView's signal
+ self.valueChanged.connect(self._statusBarSlot)
+
+ def _statusBarSlot(self, row, column, value):
+ """Update status bar with coordinates/value from plots."""
+ if numpy.isnan(row):
+ msg = 'Column: %d, Sum: %g' % (int(column), value)
+ elif numpy.isnan(column):
+ msg = 'Row: %d, Sum: %g' % (int(row), value)
+ else:
+ msg = 'Position: (%d, %d), Value: %g' % (int(row), int(column),
+ value)
+ if self._dataInfo is not None:
+ msg = self._dataInfo + ', ' + msg
+
+ self.statusBar().showMessage(msg)
+
+ def setImage(self, image, *args, **kwargs):
+ """Set the displayed image.
+
+ See :meth:`ImageView.setImage` for details.
+ """
+ if hasattr(image, 'dtype') and hasattr(image, 'shape'):
+ assert len(image.shape) == 2
+ height, width = image.shape
+ self._dataInfo = 'Data: %dx%d (%s)' % (width, height,
+ str(image.dtype))
+ self.statusBar().showMessage(self._dataInfo)
+ else:
+ self._dataInfo = None
+
+ # Set the new image in ImageView widget
+ super(ImageViewMainWindow, self).setImage(image, *args, **kwargs)
+ self.setStatusBar(None)
diff --git a/silx/gui/plot/Interaction.py b/silx/gui/plot/Interaction.py
new file mode 100644
index 0000000..f09b9bc
--- /dev/null
+++ b/silx/gui/plot/Interaction.py
@@ -0,0 +1,300 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides an implementation of state machines for interaction.
+
+Sample code of a state machine with two states ('idle' and 'active')
+with transitions on left button press/release:
+
+.. code-block:: python
+
+ from silx.gui.plot.Interaction import *
+
+ class SampleStateMachine(StateMachine):
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('active')
+
+ class Active(State):
+ def enterState(self):
+ print('Enabled') # Handle enter active state here
+
+ def leaveState(self):
+ print('Disabled') # Handle leave active state here
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('idle')
+
+ def __init__(self):
+ # State machine has 2 states
+ states = {
+ 'idle': SampleStateMachine.Idle,
+ 'active': SampleStateMachine.Active
+ }
+ super(TwoStates, self).__init__(states, 'idle')
+ # idle is the initial state
+
+ stateMachine = SampleStateMachine()
+
+ # Triggers a transition to the Active state:
+ stateMachine.handleEvent('press', 0, 0, LEFT_BTN)
+
+ # Triggers a transition to the Idle state:
+ stateMachine.handleEvent('release', 0, 0, LEFT_BTN)
+
+See :class:`ClickOrDrag` for another example of a state machine.
+
+See `Renaud Blanch, Michel Beaudouin-Lafon.
+Programming Rich Interactions using the Hierarchical State Machine Toolkit.
+In Proceedings of AVI 2006. p 51-58.
+<http://iihm.imag.fr/en/publication/BB06a/>`_
+for a discussion of using (hierarchical) state machines for interaction.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/02/2016"
+
+
+import weakref
+
+
+# state machine ###############################################################
+
+class State(object):
+ """Base class for the states of a state machine.
+
+ This class is meant to be subclassed.
+ """
+
+ def __init__(self, machine):
+ """State instances should be created by the :class:`StateMachine`.
+
+ They are not intended to be used outside this context.
+
+ :param machine: The state machine instance this state belongs to.
+ :type machine: StateMachine
+ """
+ self._machineRef = weakref.ref(machine) # Prevent cyclic reference
+
+ @property
+ def machine(self):
+ """The state machine this state belongs to.
+
+ Useful to access data or methods that are shared across states.
+ """
+ machine = self._machineRef()
+ if machine is not None:
+ return machine
+ else:
+ raise RuntimeError("Associated StateMachine is not valid")
+
+ def goto(self, state, *args, **kwargs):
+ """Performs a transition to a new state.
+
+ Extra arguments are passed to the :meth:`enterState` method of the
+ new state.
+
+ :param str state: The name of the state to go to.
+ """
+ self.machine._goto(state, *args, **kwargs)
+
+ def enterState(self, *args, **kwargs):
+ """Called when the state machine enters this state.
+
+ Arguments are those provided to the :meth:`goto` method that
+ triggered the transition to this state.
+ """
+ pass
+
+ def leaveState(self):
+ """Called when the state machine leaves this state
+ (i.e., when :meth:`goto` is called).
+ """
+ pass
+
+
+class StateMachine(object):
+ """State machine controller.
+
+ This is the entry point of a state machine.
+ It is in charge of dispatching received event and handling the
+ current active state.
+ """
+
+ def __init__(self, states, initState, *args, **kwargs):
+ """Create a state machine controller with an initial state.
+
+ Extra arguments are passed to the :meth:`enterState` method
+ of the initState.
+
+ :param states: All states of the state machine
+ :type states: dict of: {str name: State subclass}
+ :param str initState: Key of the initial state in states
+ """
+ self.states = states
+
+ self.state = self.states[initState](self)
+ self.state.enterState(*args, **kwargs)
+
+ def _goto(self, state, *args, **kwargs):
+ self.state.leaveState()
+ self.state = self.states[state](self)
+ self.state.enterState(*args, **kwargs)
+
+ def handleEvent(self, eventName, *args, **kwargs):
+ """Process an event with the state machine.
+
+ This method looks up for an event handler in the current state
+ and then in the :class:`StateMachine` instance.
+ Handler are looked up as 'onEventName' method.
+ If a handler is found, it is called with the provided extra
+ arguments, and this method returns the return value of the
+ handler.
+ If no handler is found, this method returns None.
+
+ :param str eventName: Name of the event to handle
+ :returns: The return value of the handler or None
+ """
+ handlerName = 'on' + eventName[0].upper() + eventName[1:]
+ try:
+ handler = getattr(self.state, handlerName)
+ except AttributeError:
+ try:
+ handler = getattr(self, handlerName)
+ except AttributeError:
+ handler = None
+ if handler is not None:
+ return handler(*args, **kwargs)
+
+
+# clickOrDrag #################################################################
+
+LEFT_BTN = 'left'
+"""Left mouse button."""
+
+RIGHT_BTN = 'right'
+"""Right mouse button."""
+
+MIDDLE_BTN = 'middle'
+"""Middle mouse button."""
+
+
+class ClickOrDrag(StateMachine):
+ """State machine for left and right click and left drag interaction.
+
+ It is intended to be used through subclassing by overriding
+ :meth:`click`, :meth:`beginDrag`, :meth:`drag` and :meth:`endDrag`.
+ """
+
+ DRAG_THRESHOLD_SQUARE_DIST = 5 ** 2
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('clickOrDrag', x, y)
+ return True
+ elif btn == RIGHT_BTN:
+ self.goto('rightClick', x, y)
+ return True
+
+ class RightClick(State):
+ def onMove(self, x, y):
+ self.goto('idle')
+
+ def onRelease(self, x, y, btn):
+ if btn == RIGHT_BTN:
+ self.machine.click(x, y, btn)
+ self.goto('idle')
+
+ class ClickOrDrag(State):
+ def enterState(self, x, y):
+ self.initPos = x, y
+
+ def onMove(self, x, y):
+ dx2 = (x - self.initPos[0]) ** 2
+ dy2 = (y - self.initPos[1]) ** 2
+ if (dx2 + dy2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST:
+ self.goto('drag', self.initPos, (x, y))
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.machine.click(x, y, btn)
+ self.goto('idle')
+
+ class Drag(State):
+ def enterState(self, initPos, curPos):
+ self.initPos = initPos
+ self.machine.beginDrag(*initPos)
+ self.machine.drag(*curPos)
+
+ def onMove(self, x, y):
+ self.machine.drag(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.machine.endDrag(self.initPos, (x, y))
+ self.goto('idle')
+
+ def __init__(self):
+ states = {
+ 'idle': ClickOrDrag.Idle,
+ 'rightClick': ClickOrDrag.RightClick,
+ 'clickOrDrag': ClickOrDrag.ClickOrDrag,
+ 'drag': ClickOrDrag.Drag
+ }
+ super(ClickOrDrag, self).__init__(states, 'idle')
+
+ def click(self, x, y, btn):
+ """Called upon a left or right button click.
+
+ To override in a subclass.
+ """
+ pass
+
+ def beginDrag(self, x, y):
+ """Called at the beginning of a drag gesture with left button
+ pressed.
+
+ To override in a subclass.
+ """
+ pass
+
+ def drag(self, x, y):
+ """Called on mouse moved during a drag gesture.
+
+ To override in a subclass.
+ """
+ pass
+
+ def endDrag(self, startPoint, endPoint):
+ """Called at the end of a drag gesture when the left button is
+ released.
+
+ To override in a subclass.
+ """
+ pass
diff --git a/silx/gui/plot/LegendSelector.py b/silx/gui/plot/LegendSelector.py
new file mode 100644
index 0000000..3af9050
--- /dev/null
+++ b/silx/gui/plot/LegendSelector.py
@@ -0,0 +1,1087 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget displaying curves legends and allowing to operate on curves.
+
+This widget is meant to work with :class:`PlotWindow`.
+"""
+
+__authors__ = ["V.A. Sole", "T. Rueter", "T. Vincent"]
+__license__ = "MIT"
+__data__ = "28/04/2016"
+
+
+import logging
+import weakref
+
+from .. import qt
+
+
+_logger = logging.getLogger(__name__)
+
+# Build all symbols
+# Courtesy of the pyqtgraph project
+Symbols = dict([(name, qt.QPainterPath())
+ for name in ['o', 's', 't', 'd', '+', 'x', '.', ',']])
+Symbols['o'].addEllipse(qt.QRectF(.1, .1, .8, .8))
+Symbols['.'].addEllipse(qt.QRectF(.3, .3, .4, .4))
+Symbols[','].addEllipse(qt.QRectF(.4, .4, .2, .2))
+Symbols['s'].addRect(qt.QRectF(.1, .1, .8, .8))
+
+coords = {
+ 't': [(0.5, 0.), (.1, .8), (.9, .8)],
+ 'd': [(0.1, 0.5), (0.5, 0.), (0.9, 0.5), (0.5, 1.)],
+ '+': [(0.0, 0.40), (0.40, 0.40), (0.40, 0.), (0.60, 0.),
+ (0.60, 0.40), (1., 0.40), (1., 0.60), (0.60, 0.60),
+ (0.60, 1.), (0.40, 1.), (0.40, 0.60), (0., 0.60)],
+ 'x': [(0.0, 0.40), (0.40, 0.40), (0.40, 0.), (0.60, 0.),
+ (0.60, 0.40), (1., 0.40), (1., 0.60), (0.60, 0.60),
+ (0.60, 1.), (0.40, 1.), (0.40, 0.60), (0., 0.60)]
+}
+for s, c in coords.items():
+ Symbols[s].moveTo(*c[0])
+ for x, y in c[1:]:
+ Symbols[s].lineTo(x, y)
+ Symbols[s].closeSubpath()
+tr = qt.QTransform()
+tr.rotate(45)
+Symbols['x'].translate(qt.QPointF(-0.5, -0.5))
+Symbols['x'] = tr.map(Symbols['x'])
+Symbols['x'].translate(qt.QPointF(0.5, 0.5))
+
+NoSymbols = (None, 'None', 'none', '', ' ')
+"""List of values resulting in no symbol being displayed for a curve"""
+
+
+LineStyles = {
+ None: qt.Qt.NoPen,
+ 'None': qt.Qt.NoPen,
+ 'none': qt.Qt.NoPen,
+ '': qt.Qt.NoPen,
+ ' ': qt.Qt.NoPen,
+ '-': qt.Qt.SolidLine,
+ '--': qt.Qt.DashLine,
+ ':': qt.Qt.DotLine,
+ '-.': qt.Qt.DashDotLine
+}
+"""Conversion from matplotlib-like linestyle to Qt"""
+
+NoLineStyle = (None, 'None', 'none', '', ' ')
+"""List of style values resulting in no line being displayed for a curve"""
+
+
+class LegendIcon(qt.QWidget):
+ """Object displaying a curve linestyle and symbol."""
+
+ def __init__(self, parent=None):
+ super(LegendIcon, self).__init__(parent)
+
+ # Visibilities
+ self.showLine = True
+ self.showSymbol = True
+
+ # Line attributes
+ self.lineStyle = qt.Qt.NoPen
+ self.lineWidth = 1.
+ self.lineColor = qt.Qt.green
+
+ self.symbol = ''
+ # Symbol attributes
+ self.symbolStyle = qt.Qt.SolidPattern
+ self.symbolColor = qt.Qt.green
+ self.symbolOutlineBrush = qt.QBrush(qt.Qt.white)
+
+ # Control widget size: sizeHint "is the only acceptable
+ # alternative, so the widget can never grow or shrink"
+ # (c.f. Qt Doc, enum QSizePolicy::Policy)
+ self.setSizePolicy(qt.QSizePolicy.Fixed,
+ qt.QSizePolicy.Fixed)
+
+ def sizeHint(self):
+ return qt.QSize(50, 15)
+
+ # Modify Symbol
+ def setSymbol(self, symbol):
+ symbol = str(symbol)
+ if symbol not in NoSymbols:
+ if symbol not in Symbols:
+ raise ValueError("Unknown symbol: <%s>" % symbol)
+ self.symbol = symbol
+ # self.update() after set...?
+ # Does not seem necessary
+
+ def setSymbolColor(self, color):
+ """
+ :param color: determines the symbol color
+ :type style: qt.QColor
+ """
+ self.symbolColor = qt.QColor(color)
+
+ # Modify Line
+
+ def setLineColor(self, color):
+ self.lineColor = qt.QColor(color)
+
+ def setLineWidth(self, width):
+ self.lineWidth = float(width)
+
+ def setLineStyle(self, style):
+ """Set the linestyle.
+
+ Possible line styles:
+
+ - '', ' ', 'None': No line
+ - '-': solid
+ - '--': dashed
+ - ':': dotted
+ - '-.': dash and dot
+
+ :param str style: The linestyle to use
+ """
+ if style not in LineStyles:
+ raise ValueError('Unknown style: %s', style)
+ self.lineStyle = LineStyles[style]
+
+ # Paint
+
+ def paintEvent(self, event):
+ """
+ :param event: event
+ :type event: QPaintEvent
+ """
+ painter = qt.QPainter(self)
+ self.paint(painter, event.rect(), self.palette())
+
+ def paint(self, painter, rect, palette):
+ painter.save()
+ painter.setRenderHint(qt.QPainter.Antialiasing)
+ # Scale painter to the icon height
+ # current -> width = 2.5, height = 1.0
+ scale = float(self.height())
+ ratio = float(self.width()) / scale
+ painter.scale(scale,
+ scale)
+ symbolOffset = qt.QPointF(.5 * (ratio - 1.), 0.)
+ # Determine and scale offset
+ offset = qt.QPointF(float(rect.left()) / scale, float(rect.top()) / scale)
+ # Draw BG rectangle (for debugging)
+ # bottomRight = qt.QPointF(
+ # float(rect.right())/scale,
+ # float(rect.bottom())/scale)
+ # painter.fillRect(qt.QRectF(offset, bottomRight),
+ # qt.QBrush(qt.Qt.green))
+ llist = []
+ if self.showLine:
+ linePath = qt.QPainterPath()
+ linePath.moveTo(0., 0.5)
+ linePath.lineTo(ratio, 0.5)
+ # linePath.lineTo(2.5, 0.5)
+ linePen = qt.QPen(
+ qt.QBrush(self.lineColor),
+ (self.lineWidth / self.height()),
+ self.lineStyle,
+ qt.Qt.FlatCap
+ )
+ llist.append((linePath,
+ linePen,
+ qt.QBrush(self.lineColor)))
+ if (self.showSymbol and len(self.symbol) and
+ self.symbol not in NoSymbols):
+ # PITFALL ahead: Let this be a warning to others
+ # symbolPath = Symbols[self.symbol]
+ # Copy before translate! Dict is a mutable type
+ symbolPath = qt.QPainterPath(Symbols[self.symbol])
+ symbolPath.translate(symbolOffset)
+ symbolBrush = qt.QBrush(
+ self.symbolColor,
+ self.symbolStyle
+ )
+ symbolPen = qt.QPen(
+ self.symbolOutlineBrush, # Brush
+ 1. / self.height(), # Width
+ qt.Qt.SolidLine # Style
+ )
+ llist.append((symbolPath,
+ symbolPen,
+ symbolBrush))
+ # Draw
+ for path, pen, brush in llist:
+ path.translate(offset)
+ painter.setPen(pen)
+ painter.setBrush(brush)
+ painter.drawPath(path)
+ painter.restore()
+
+
+class LegendModel(qt.QAbstractListModel):
+ """Data model of curve legends.
+
+ It holds the information of the curve:
+
+ - color
+ - line width
+ - line style
+ - visibility of the lines
+ - symbol
+ - visibility of the symbols
+ """
+ iconColorRole = qt.Qt.UserRole + 0
+ iconLineWidthRole = qt.Qt.UserRole + 1
+ iconLineStyleRole = qt.Qt.UserRole + 2
+ showLineRole = qt.Qt.UserRole + 3
+ iconSymbolRole = qt.Qt.UserRole + 4
+ showSymbolRole = qt.Qt.UserRole + 5
+
+ def __init__(self, legendList=None, parent=None):
+ super(LegendModel, self).__init__(parent)
+ if legendList is None:
+ legendList = []
+ self.legendList = []
+ self.insertLegendList(0, legendList)
+
+ def __getitem__(self, idx):
+ if idx >= len(self.legendList):
+ raise IndexError('list index out of range')
+ return self.legendList[idx]
+
+ def rowCount(self, modelIndex=None):
+ return len(self.legendList)
+
+ def flags(self, index):
+ return (qt.Qt.ItemIsEditable |
+ qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsSelectable)
+
+ def data(self, modelIndex, role):
+ if modelIndex.isValid:
+ idx = modelIndex.row()
+ else:
+ return None
+ if idx >= len(self.legendList):
+ raise IndexError('list index out of range')
+
+ item = self.legendList[idx]
+ if role == qt.Qt.DisplayRole:
+ # Data to be rendered in the form of text
+ legend = str(item[0])
+ return legend
+ elif role == qt.Qt.SizeHintRole:
+ # size = qt.QSize(200,50)
+ _logger.warning('LegendModel -- size hint role not implemented')
+ return qt.QSize()
+ elif role == qt.Qt.TextAlignmentRole:
+ alignment = qt.Qt.AlignVCenter | qt.Qt.AlignLeft
+ return alignment
+ elif role == qt.Qt.BackgroundRole:
+ # Background color, must be QBrush
+ if idx % 2:
+ brush = qt.QBrush(qt.QColor(240, 240, 240))
+ else:
+ brush = qt.QBrush(qt.Qt.white)
+ return brush
+ elif role == qt.Qt.ForegroundRole:
+ # ForegroundRole color, must be QBrush
+ brush = qt.QBrush(qt.Qt.blue)
+ return brush
+ elif role == qt.Qt.CheckStateRole:
+ return bool(item[2]) # item[2] == True
+ elif role == qt.Qt.ToolTipRole or role == qt.Qt.StatusTipRole:
+ return ''
+ elif role == self.iconColorRole:
+ return item[1]['color']
+ elif role == self.iconLineWidthRole:
+ return item[1]['linewidth']
+ elif role == self.iconLineStyleRole:
+ return item[1]['linestyle']
+ elif role == self.iconSymbolRole:
+ return item[1]['symbol']
+ elif role == self.showLineRole:
+ return item[3]
+ elif role == self.showSymbolRole:
+ return item[4]
+ else:
+ _logger.info('Unkown role requested: %s', str(role))
+ return None
+
+ def setData(self, modelIndex, value, role):
+ if modelIndex.isValid:
+ idx = modelIndex.row()
+ else:
+ return None
+ if idx >= len(self.legendList):
+ # raise IndexError('list index out of range')
+ _logger.warning(
+ 'setData -- List index out of range, idx: %d', idx)
+ return None
+
+ item = self.legendList[idx]
+ try:
+ if role == qt.Qt.DisplayRole:
+ # Set legend
+ item[0] = str(value)
+ elif role == self.iconColorRole:
+ item[1]['color'] = qt.QColor(value)
+ elif role == self.iconLineWidthRole:
+ item[1]['linewidth'] = int(value)
+ elif role == self.iconLineStyleRole:
+ item[1]['linestyle'] = str(value)
+ elif role == self.iconSymbolRole:
+ item[1]['symbol'] = str(value)
+ elif role == qt.Qt.CheckStateRole:
+ item[2] = value
+ elif role == self.showLineRole:
+ item[3] = value
+ elif role == self.showSymbolRole:
+ item[4] = value
+ except ValueError:
+ _logger.warning('Conversion failed:\n\tvalue: %s\n\trole: %s',
+ str(value), str(role))
+ # Can that be right? Read docs again..
+ self.dataChanged.emit(modelIndex, modelIndex)
+ return True
+
+ def insertLegendList(self, row, llist):
+ """
+ :param int row: Determines after which row the items are inserted
+ :param llist: Carries the new legend information
+ :type llist: List
+ """
+ modelIndex = self.createIndex(row, 0)
+ count = len(llist)
+ super(LegendModel, self).beginInsertRows(modelIndex,
+ row,
+ row + count)
+ head = self.legendList[0:row]
+ tail = self.legendList[row:]
+ new = []
+ for (legend, icon) in llist:
+ linestyle = icon.get('linestyle', None)
+ if linestyle in NoLineStyle:
+ # Curve had no line, give it one and hide it
+ # So when toggle line, it will display a solid line
+ showLine = False
+ icon['linestyle'] = '-'
+ else:
+ showLine = True
+
+ symbol = icon.get('symbol', None)
+ if symbol in NoSymbols:
+ # Curve had no symbol, give it one and hide it
+ # So when toggle symbol, it will display 'o'
+ showSymbol = False
+ icon['symbol'] = 'o'
+ else:
+ showSymbol = True
+
+ selected = icon.get('selected', True)
+ item = [legend,
+ icon,
+ selected,
+ showLine,
+ showSymbol]
+ new.append(item)
+ self.legendList = head + new + tail
+ super(LegendModel, self).endInsertRows()
+ return True
+
+ def insertRows(self, row, count, modelIndex=qt.QModelIndex()):
+ raise NotImplementedError('Use LegendModel.insertLegendList instead')
+
+ def removeRow(self, row):
+ return self.removeRows(row, 1)
+
+ def removeRows(self, row, count, modelIndex=qt.QModelIndex()):
+ length = len(self.legendList)
+ if length == 0:
+ # Nothing to do..
+ return True
+ if row < 0 or row >= length:
+ raise IndexError('Index out of range -- ' +
+ 'idx: %d, len: %d' % (row, length))
+ if count == 0:
+ return False
+ super(LegendModel, self).beginRemoveRows(modelIndex,
+ row,
+ row + count)
+ del(self.legendList[row:row + count])
+ super(LegendModel, self).endRemoveRows()
+ return True
+
+ def setEditor(self, event, editor):
+ """
+ :param str event: String that identifies the editor
+ :param editor: Widget used to change data in the underlying model
+ :type editor: QWidget
+ """
+ if event not in self.eventList:
+ raise ValueError('setEditor -- Event must be in %s' %
+ str(self.eventList))
+ self.editorDict[event] = editor
+
+
+class LegendListItemWidget(qt.QItemDelegate):
+ """Object displaying a single item (i.e., a row) in the list."""
+
+ # Notice: LegendListItem does NOT inherit
+ # from QObject, it cannot emit signals!
+
+ def __init__(self, parent=None, itemType=0):
+ super(LegendListItemWidget, self).__init__(parent)
+
+ # Dictionary to render checkboxes
+ self.cbDict = {}
+ self.labelDict = {}
+ self.iconDict = {}
+
+ # Keep checkbox and legend to get sizeHint
+ self.checkbox = qt.QCheckBox()
+ self.legend = qt.QLabel()
+ self.icon = LegendIcon()
+
+ # Context Menu and Editors
+ self.contextMenu = None
+
+ def paint(self, painter, option, modelIndex):
+ """
+ Here be docs..
+
+ :param QPainter painter:
+ :param QStyleOptionViewItem option:
+ :param QModelIndex modelIndex:
+ """
+ painter.save()
+ rect = option.rect
+
+ # Calculate the icon rectangle
+ iconSize = self.icon.sizeHint()
+ # Calculate icon position
+ x = rect.left() + 2
+ y = rect.top() + int(.5 * (rect.height() - iconSize.height()))
+ iconRect = qt.QRect(qt.QPoint(x, y), iconSize)
+
+ # Calculate label rectangle
+ legendSize = qt.QSize(rect.width() - iconSize.width() - 30,
+ rect.height())
+ # Calculate label position
+ x = rect.left() + iconRect.width()
+ y = rect.top()
+ labelRect = qt.QRect(qt.QPoint(x, y), legendSize)
+ labelRect.translate(qt.QPoint(10, 0))
+
+ # Calculate the checkbox rectangle
+ x = rect.right() - 30
+ y = rect.top()
+ chBoxRect = qt.QRect(qt.QPoint(x, y), rect.bottomRight())
+
+ # Remember the rectangles
+ idx = modelIndex.row()
+ self.cbDict[idx] = chBoxRect
+ self.iconDict[idx] = iconRect
+ self.labelDict[idx] = labelRect
+
+ # Draw background first!
+ if option.state & qt.QStyle.State_MouseOver:
+ backgroundBrush = option.palette.highlight()
+ else:
+ backgroundBrush = modelIndex.data(qt.Qt.BackgroundRole)
+ painter.fillRect(rect, backgroundBrush)
+
+ # Draw label
+ legendText = modelIndex.data(qt.Qt.DisplayRole)
+ textBrush = modelIndex.data(qt.Qt.ForegroundRole)
+ textAlign = modelIndex.data(qt.Qt.TextAlignmentRole)
+ painter.setBrush(textBrush)
+ painter.setFont(self.legend.font())
+ painter.drawText(labelRect, textAlign, legendText)
+
+ # Draw icon
+ iconColor = modelIndex.data(LegendModel.iconColorRole)
+ iconLineWidth = modelIndex.data(LegendModel.iconLineWidthRole)
+ iconLineStyle = modelIndex.data(LegendModel.iconLineStyleRole)
+ iconSymbol = modelIndex.data(LegendModel.iconSymbolRole)
+ icon = LegendIcon()
+ icon.resize(iconRect.size())
+ icon.move(iconRect.topRight())
+ icon.showSymbol = modelIndex.data(LegendModel.showSymbolRole)
+ icon.showLine = modelIndex.data(LegendModel.showLineRole)
+ icon.setSymbolColor(iconColor)
+ icon.setLineColor(iconColor)
+ icon.setLineWidth(iconLineWidth)
+ icon.setLineStyle(iconLineStyle)
+ icon.setSymbol(iconSymbol)
+ icon.symbolOutlineBrush = backgroundBrush
+ icon.paint(painter, iconRect, option.palette)
+
+ # Draw the checkbox
+ if modelIndex.data(qt.Qt.CheckStateRole):
+ checkState = qt.Qt.Checked
+ else:
+ checkState = qt.Qt.Unchecked
+
+ self.drawCheck(
+ painter, qt.QStyleOptionViewItem(), chBoxRect, checkState)
+
+ painter.restore()
+
+ def editorEvent(self, event, model, option, modelIndex):
+ # From the docs:
+ # Mouse events are sent to editorEvent()
+ # even if they don't start editing of the item.
+ if event.button() == qt.Qt.RightButton and self.contextMenu:
+ self.contextMenu.exec_(event.globalPos(), modelIndex)
+ return True
+ elif event.button() == qt.Qt.LeftButton:
+ # Check if checkbox was clicked
+ idx = modelIndex.row()
+ cbRect = self.cbDict[idx]
+ if cbRect.contains(event.pos()):
+ # Toggle checkbox
+ model.setData(modelIndex,
+ not modelIndex.data(qt.Qt.CheckStateRole),
+ qt.Qt.CheckStateRole)
+ event.ignore()
+ return True
+ else:
+ return super(LegendListItemWidget, self).editorEvent(
+ event, model, option, modelIndex)
+
+ def createEditor(self, parent, option, idx):
+ _logger.info('### Editor request ###')
+
+ def sizeHint(self, option, idx):
+ # return qt.QSize(68,24)
+ iconSize = self.icon.sizeHint()
+ legendSize = self.legend.sizeHint()
+ checkboxSize = self.checkbox.sizeHint()
+ height = max([iconSize.height(),
+ legendSize.height(),
+ checkboxSize.height()]) + 4
+ width = iconSize.width() + legendSize.width() + checkboxSize.width()
+ return qt.QSize(width, height)
+
+
+class LegendListView(qt.QListView):
+ """Widget displaying a list of curve legends, line style and symbol."""
+
+ sigLegendSignal = qt.Signal(object)
+ """Signal emitting a dict when an action is triggered by the user."""
+
+ __mouseClickedEvent = 'mouseClicked'
+ __checkBoxClickedEvent = 'checkBoxClicked'
+ __legendClickedEvent = 'legendClicked'
+
+ def __init__(self, parent=None, model=None, contextMenu=None):
+ super(LegendListView, self).__init__(parent)
+ self.__lastButton = None
+ self.__lastClickPos = None
+ self.__lastModelIdx = None
+ # Set default delegate
+ self.setItemDelegate(LegendListItemWidget())
+ # Set default editors
+ # self.setSizePolicy(qt.QSizePolicy.MinimumExpanding,
+ # qt.QSizePolicy.MinimumExpanding)
+ # Set edit triggers by hand using self.edit(QModelIndex)
+ # in mousePressEvent (better to control than signals)
+ self.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
+
+ # Control layout
+ # self.setBatchSize(2)
+ # self.setLayoutMode(qt.QListView.Batched)
+ # self.setFlow(qt.QListView.LeftToRight)
+
+ # Control selection
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+
+ if model is None:
+ model = LegendModel()
+ self.setModel(model)
+ self.setContextMenu(contextMenu)
+
+ def setLegendList(self, legendList, row=None):
+ self.clear()
+ if row is None:
+ row = 0
+ model = self.model()
+ model.insertLegendList(row, legendList)
+ _logger.debug('LegendListView.setLegendList(legendList) finished')
+
+ def clear(self):
+ model = self.model()
+ model.removeRows(0, model.rowCount())
+ _logger.debug('LegendListView.clear() finished')
+
+ def setContextMenu(self, contextMenu=None):
+ delegate = self.itemDelegate()
+ if isinstance(delegate, LegendListItemWidget) and self.model():
+ if contextMenu is None:
+ delegate.contextMenu = LegendListContextMenu(self.model())
+ delegate.contextMenu.sigContextMenu.connect(
+ self._contextMenuSlot)
+ else:
+ delegate.contextMenu = contextMenu
+
+ def __getitem__(self, idx):
+ model = self.model()
+ try:
+ item = model[idx]
+ except ValueError:
+ item = None
+ return item
+
+ def _contextMenuSlot(self, ddict):
+ self.sigLegendSignal.emit(ddict)
+
+ def mousePressEvent(self, event):
+ self.__lastButton = event.button()
+ self.__lastPosition = event.pos()
+ super(LegendListView, self).mousePressEvent(event)
+ # call _handleMouseClick after editing was handled
+ # If right click (context menu) is aborted, no
+ # signal is emitted..
+ self._handleMouseClick(self.indexAt(self.__lastPosition))
+
+ def mouseDoubleClickEvent(self, event):
+ self.__lastButton = event.button()
+ self.__lastPosition = event.pos()
+ super(LegendListView, self).mouseDoubleClickEvent(event)
+ # call _handleMouseClick after editing was handled
+ # If right click (context menu) is aborted, no
+ # signal is emitted..
+ self._handleMouseClick(self.indexAt(self.__lastPosition))
+
+ def mouseMoveEvent(self, event):
+ # LegendListView.mouseMoveEvent is overwritten
+ # to suppress unwanted behavior in the delegate.
+ pass
+
+ def mouseReleaseEvent(self, event):
+ # LegendListView.mouseReleaseEvent is overwritten
+ # to subpress unwanted behavior in the delegate.
+ pass
+
+ def _handleMouseClick(self, modelIndex):
+ """
+ Distinguish between mouse click on Legend
+ and mouse click on CheckBox by setting the
+ currentCheckState attribute in LegendListItem.
+
+ Emits signal sigLegendSignal(ddict)
+
+ :param QModelIndex modelIndex: index of the clicked item
+ """
+ _logger.debug('self._handleMouseClick called')
+ if self.__lastButton not in [qt.Qt.LeftButton,
+ qt.Qt.RightButton]:
+ return
+ if not modelIndex.isValid():
+ _logger.debug('_handleMouseClick -- Invalid QModelIndex')
+ return
+ # model = self.model()
+ idx = modelIndex.row()
+
+ delegate = self.itemDelegate()
+ cbClicked = False
+ if isinstance(delegate, LegendListItemWidget):
+ for cbRect in delegate.cbDict.values():
+ if cbRect.contains(self.__lastPosition):
+ cbClicked = True
+ break
+
+ # TODO: Check for doubleclicks on legend/icon and spawn editors
+
+ ddict = {
+ 'legend': str(modelIndex.data(qt.Qt.DisplayRole)),
+ 'icon': {
+ 'linewidth': str(modelIndex.data(
+ LegendModel.iconLineWidthRole)),
+ 'linestyle': str(modelIndex.data(
+ LegendModel.iconLineStyleRole)),
+ 'symbol': str(modelIndex.data(LegendModel.iconSymbolRole))
+ },
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data())
+ }
+ if self.__lastButton == qt.Qt.RightButton:
+ _logger.debug('Right clicked')
+ ddict['button'] = "right"
+ ddict['event'] = self.__mouseClickedEvent
+ elif cbClicked:
+ _logger.debug('CheckBox clicked')
+ ddict['button'] = "left"
+ ddict['event'] = self.__checkBoxClickedEvent
+ else:
+ _logger.debug('Legend clicked')
+ ddict['button'] = "left"
+ ddict['event'] = self.__legendClickedEvent
+ _logger.debug(' idx: %d\n ddict: %s', idx, str(ddict))
+ self.sigLegendSignal.emit(ddict)
+
+
+class LegendListContextMenu(qt.QMenu):
+ """Contextual menu associated to items in a :class:`LegendListView`."""
+
+ sigContextMenu = qt.Signal(object)
+ """Signal emitting a dict upon contextual menu actions."""
+
+ def __init__(self, model):
+ super(LegendListContextMenu, self).__init__(parent=None)
+ self.model = model
+
+ self.addAction('Set Active', self.setActiveAction)
+ self.addAction('Map to left', self.mapToLeftAction)
+ self.addAction('Map to right', self.mapToRightAction)
+
+ self._pointsAction = self.addAction(
+ 'Points', self.togglePointsAction)
+ self._pointsAction.setCheckable(True)
+
+ self._linesAction = self.addAction('Lines', self.toggleLinesAction)
+ self._linesAction.setCheckable(True)
+
+ self.addAction('Remove curve', self.removeItemAction)
+ self.addAction('Rename curve', self.renameItemAction)
+
+ def exec_(self, pos, idx):
+ self.__currentIdx = idx
+
+ # Set checkable action state
+ modelIndex = self.currentIdx()
+ self._pointsAction.setChecked(
+ modelIndex.data(LegendModel.showSymbolRole))
+ self._linesAction.setChecked(
+ modelIndex.data(LegendModel.showLineRole))
+
+ super(LegendListContextMenu, self).popup(pos)
+
+ def currentIdx(self):
+ return self.__currentIdx
+
+ def mapToLeftAction(self):
+ _logger.debug('LegendListContextMenu.mapToLeftAction called')
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "mapToLeft"
+ }
+ self.sigContextMenu.emit(ddict)
+
+ def mapToRightAction(self):
+ _logger.debug('LegendListContextMenu.mapToRightAction called')
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "mapToRight"
+ }
+ self.sigContextMenu.emit(ddict)
+
+ def removeItemAction(self):
+ _logger.debug('LegendListContextMenu.removeCurveAction called')
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "removeCurve"
+ }
+ self.model.removeRow(modelIndex.row())
+ self.sigContextMenu.emit(ddict)
+
+ def renameItemAction(self):
+ _logger.debug('LegendListContextMenu.renameCurveAction called')
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "renameCurve"
+ }
+ self.sigContextMenu.emit(ddict)
+
+ def toggleLinesAction(self):
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ }
+ linestyle = modelIndex.data(LegendModel.iconLineStyleRole)
+ visible = not modelIndex.data(LegendModel.showLineRole)
+ _logger.debug('toggleLinesAction -- lines visible: %s', str(visible))
+ ddict['event'] = "toggleLine"
+ ddict['line'] = visible
+ ddict['linestyle'] = linestyle if visible else ''
+ self.model.setData(modelIndex, visible, LegendModel.showLineRole)
+ self.sigContextMenu.emit(ddict)
+
+ def togglePointsAction(self):
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ }
+ flag = modelIndex.data(LegendModel.showSymbolRole)
+ symbol = modelIndex.data(LegendModel.iconSymbolRole)
+ visible = not flag or symbol in NoSymbols
+ _logger.debug(
+ 'togglePointsAction -- Symbols visible: %s', str(visible))
+
+ ddict['event'] = "togglePoints"
+ ddict['points'] = visible
+ ddict['symbol'] = symbol if visible else ''
+ self.model.setData(modelIndex, visible, LegendModel.showSymbolRole)
+ self.sigContextMenu.emit(ddict)
+
+ def setActiveAction(self):
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ _logger.debug('setActiveAction -- active curve: %s', legend)
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "setActiveCurve",
+ }
+ self.sigContextMenu.emit(ddict)
+
+
+class RenameCurveDialog(qt.QDialog):
+ """Dialog box to input the name of a curve."""
+
+ def __init__(self, parent=None, current="", curves=()):
+ super(RenameCurveDialog, self).__init__(parent)
+ self.setWindowTitle("Rename Curve %s" % current)
+ self.curves = curves
+ layout = qt.QVBoxLayout(self)
+ self.lineEdit = qt.QLineEdit(self)
+ self.lineEdit.setText(current)
+ self.hbox = qt.QWidget(self)
+ self.hboxLayout = qt.QHBoxLayout(self.hbox)
+ self.hboxLayout.addStretch(1)
+ self.okButton = qt.QPushButton(self.hbox)
+ self.okButton.setText('OK')
+ self.hboxLayout.addWidget(self.okButton)
+ self.cancelButton = qt.QPushButton(self.hbox)
+ self.cancelButton.setText('Cancel')
+ self.hboxLayout.addWidget(self.cancelButton)
+ self.hboxLayout.addStretch(1)
+ layout.addWidget(self.lineEdit)
+ layout.addWidget(self.hbox)
+ self.okButton.clicked.connect(self.preAccept)
+ self.cancelButton.clicked.connect(self.reject)
+
+ def preAccept(self):
+ text = str(self.lineEdit.text())
+ addedText = ""
+ if len(text):
+ if text not in self.curves:
+ self.accept()
+ return
+ else:
+ addedText = "Curve already exists."
+ text = "Invalid Curve Name"
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setWindowTitle(text)
+ text += "\n%s" % addedText
+ msg.setText(text)
+ msg.exec_()
+
+ def getText(self):
+ return str(self.lineEdit.text())
+
+
+class LegendsDockWidget(qt.QDockWidget):
+ """QDockWidget with a :class:`LegendSelector` connected to a PlotWindow.
+
+ It makes the link between the LegendListView widget and the PlotWindow.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: :class:`.PlotWindow` instance on which to operate
+ """
+
+ def __init__(self, parent=None, plot=None):
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+ self._isConnected = False # True if widget connected to plot signals
+
+ super(LegendsDockWidget, self).__init__("Legends", parent)
+
+ self._legendWidget = LegendListView()
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.setWidget(self._legendWidget)
+
+ self.visibilityChanged.connect(
+ self._visibilityChangedHandler)
+
+ self._legendWidget.sigLegendSignal.connect(self._legendSignalHandler)
+
+ @property
+ def plot(self):
+ """The :class:`.PlotWindow` this widget is attached to."""
+ return self._plotRef()
+
+ def renameCurve(self, oldLegend, newLegend):
+ """Change the name of a curve using remove and addCurve
+
+ :param str oldLegend: The legend of the curve to be change
+ :param str newLegend: The new legend of the curve
+ """
+ curve = self.plot.getCurve(oldLegend)
+ self.plot.remove(oldLegend, kind='curve')
+ self.plot.addCurve(curve.getXData(copy=False),
+ curve.getYData(copy=False),
+ legend=newLegend,
+ info=curve.getInfo(),
+ color=curve.getColor(),
+ symbol=curve.getSymbol(),
+ linewidth=curve.getLineWidth(),
+ linestyle=curve.getLineStyle(),
+ xlabel=curve.getXLabel(),
+ ylabel=curve.getYLabel(),
+ xerror=curve.getXErrorData(copy=False),
+ yerror=curve.getYErrorData(copy=False),
+ z=curve.getZValue(),
+ selectable=curve.isSelectable(),
+ fill=curve.isFill(),
+ resetzoom=False)
+
+ def _legendSignalHandler(self, ddict):
+ """Handles events from the LegendListView signal"""
+ _logger.debug("Legend signal ddict = %s", str(ddict))
+
+ if ddict['event'] == "legendClicked":
+ if ddict['button'] == "left":
+ self.plot.setActiveCurve(ddict['legend'])
+
+ elif ddict['event'] == "removeCurve":
+ self.plot.removeCurve(ddict['legend'])
+
+ elif ddict['event'] == "renameCurve":
+ curveList = self.plot.getAllCurves(just_legend=True)
+ oldLegend = ddict['legend']
+ dialog = RenameCurveDialog(self.plot, oldLegend, curveList)
+ ret = dialog.exec_()
+ if ret:
+ newLegend = dialog.getText()
+ self.renameCurve(oldLegend, newLegend)
+
+ elif ddict['event'] == "setActiveCurve":
+ self.plot.setActiveCurve(ddict['legend'])
+
+ elif ddict['event'] == "checkBoxClicked":
+ self.plot.hideCurve(ddict['legend'], not ddict['selected'])
+
+ elif ddict['event'] in ["mapToRight", "mapToLeft"]:
+ legend = ddict['legend']
+ curve = self.plot.getCurve(legend)
+ yaxis = 'right' if ddict['event'] == 'mapToRight' else 'left'
+ self.plot.addCurve(x=curve.getXData(copy=False),
+ y=curve.getYData(copy=False),
+ legend=curve.getLegend(),
+ info=curve.getInfo(),
+ yaxis=yaxis)
+
+ elif ddict['event'] == "togglePoints":
+ legend = ddict['legend']
+ curve = self.plot.getCurve(legend)
+ symbol = ddict['symbol'] if ddict['points'] else ''
+ self.plot.addCurve(x=curve.getXData(copy=False),
+ y=curve.getYData(copy=False),
+ legend=curve.getLegend(),
+ info=curve.getInfo(),
+ symbol=symbol)
+
+ elif ddict['event'] == "toggleLine":
+ legend = ddict['legend']
+ curve = self.plot.getCurve(legend)
+ linestyle = ddict['linestyle'] if ddict['line'] else ''
+ self.plot.addCurve(x=curve.getXData(copy=False),
+ y=curve.getYData(copy=False),
+ legend=curve.getLegend(),
+ info=curve.getInfo(),
+ linestyle=linestyle)
+
+ else:
+ _logger.debug("unhandled event %s", str(ddict['event']))
+
+ def updateLegends(self, *args):
+ """Sync the LegendSelector widget displayed info with the plot.
+ """
+ legendList = []
+ for curve in self.plot.getAllCurves(withhidden=True):
+ legend = curve.getLegend()
+ # Use active color if curve is active
+ if legend == self.plot.getActiveCurve(just_legend=True):
+ color = qt.QColor(self.plot.getActiveCurveColor())
+ else:
+ color = qt.QColor.fromRgbF(*curve.getColor())
+
+ curveInfo = {
+ 'color': color,
+ 'linewidth': curve.getLineWidth(),
+ 'linestyle': curve.getLineStyle(),
+ 'symbol': curve.getSymbol(),
+ 'selected': not self.plot.isCurveHidden(legend)}
+ legendList.append((legend, curveInfo))
+
+ self._legendWidget.setLegendList(legendList)
+
+ def _visibilityChangedHandler(self, visible):
+ if visible:
+ self.updateLegends()
+ if not self._isConnected:
+ self.plot.sigContentChanged.connect(self.updateLegends)
+ self.plot.sigActiveCurveChanged.connect(self.updateLegends)
+ self._isConnected = True
+ else:
+ if self._isConnected:
+ self.plot.sigContentChanged.disconnect(self.updateLegends)
+ self.plot.sigActiveCurveChanged.disconnect(self.updateLegends)
+ self._isConnected = False
+
+ def showEvent(self, event):
+ """Make sure this widget is raised when it is shown
+ (when it is first created as a tab in PlotWindow or when it is shown
+ again after hiding).
+ """
+ self.raise_()
diff --git a/silx/gui/plot/MPLColormap.py b/silx/gui/plot/MPLColormap.py
new file mode 100644
index 0000000..49b11d7
--- /dev/null
+++ b/silx/gui/plot/MPLColormap.py
@@ -0,0 +1,1062 @@
+# New matplotlib colormaps by Nathaniel J. Smith, Stefan van der Walt,
+# and (in the case of viridis) Eric Firing.
+#
+# This file and the colormaps in it are released under the CC0 license /
+# public domain dedication. We would appreciate credit if you use or
+# redistribute these colormaps, but do not impose any legal restrictions.
+#
+# To the extent possible under law, the persons who associated CC0 with
+# mpl-colormaps have waived all copyright and related or neighboring rights
+# to mpl-colormaps.
+#
+# You should have received a copy of the CC0 legalcode along with this
+# work. If not, see <http://creativecommons.org/publicdomain/zero/1.0/>.
+"""Matplotlib's new colormaps"""
+
+
+from matplotlib.colors import ListedColormap
+
+
+__all__ = ['magma', 'inferno', 'plasma', 'viridis']
+
+_magma_data = [[0.001462, 0.000466, 0.013866],
+ [0.002258, 0.001295, 0.018331],
+ [0.003279, 0.002305, 0.023708],
+ [0.004512, 0.003490, 0.029965],
+ [0.005950, 0.004843, 0.037130],
+ [0.007588, 0.006356, 0.044973],
+ [0.009426, 0.008022, 0.052844],
+ [0.011465, 0.009828, 0.060750],
+ [0.013708, 0.011771, 0.068667],
+ [0.016156, 0.013840, 0.076603],
+ [0.018815, 0.016026, 0.084584],
+ [0.021692, 0.018320, 0.092610],
+ [0.024792, 0.020715, 0.100676],
+ [0.028123, 0.023201, 0.108787],
+ [0.031696, 0.025765, 0.116965],
+ [0.035520, 0.028397, 0.125209],
+ [0.039608, 0.031090, 0.133515],
+ [0.043830, 0.033830, 0.141886],
+ [0.048062, 0.036607, 0.150327],
+ [0.052320, 0.039407, 0.158841],
+ [0.056615, 0.042160, 0.167446],
+ [0.060949, 0.044794, 0.176129],
+ [0.065330, 0.047318, 0.184892],
+ [0.069764, 0.049726, 0.193735],
+ [0.074257, 0.052017, 0.202660],
+ [0.078815, 0.054184, 0.211667],
+ [0.083446, 0.056225, 0.220755],
+ [0.088155, 0.058133, 0.229922],
+ [0.092949, 0.059904, 0.239164],
+ [0.097833, 0.061531, 0.248477],
+ [0.102815, 0.063010, 0.257854],
+ [0.107899, 0.064335, 0.267289],
+ [0.113094, 0.065492, 0.276784],
+ [0.118405, 0.066479, 0.286321],
+ [0.123833, 0.067295, 0.295879],
+ [0.129380, 0.067935, 0.305443],
+ [0.135053, 0.068391, 0.315000],
+ [0.140858, 0.068654, 0.324538],
+ [0.146785, 0.068738, 0.334011],
+ [0.152839, 0.068637, 0.343404],
+ [0.159018, 0.068354, 0.352688],
+ [0.165308, 0.067911, 0.361816],
+ [0.171713, 0.067305, 0.370771],
+ [0.178212, 0.066576, 0.379497],
+ [0.184801, 0.065732, 0.387973],
+ [0.191460, 0.064818, 0.396152],
+ [0.198177, 0.063862, 0.404009],
+ [0.204935, 0.062907, 0.411514],
+ [0.211718, 0.061992, 0.418647],
+ [0.218512, 0.061158, 0.425392],
+ [0.225302, 0.060445, 0.431742],
+ [0.232077, 0.059889, 0.437695],
+ [0.238826, 0.059517, 0.443256],
+ [0.245543, 0.059352, 0.448436],
+ [0.252220, 0.059415, 0.453248],
+ [0.258857, 0.059706, 0.457710],
+ [0.265447, 0.060237, 0.461840],
+ [0.271994, 0.060994, 0.465660],
+ [0.278493, 0.061978, 0.469190],
+ [0.284951, 0.063168, 0.472451],
+ [0.291366, 0.064553, 0.475462],
+ [0.297740, 0.066117, 0.478243],
+ [0.304081, 0.067835, 0.480812],
+ [0.310382, 0.069702, 0.483186],
+ [0.316654, 0.071690, 0.485380],
+ [0.322899, 0.073782, 0.487408],
+ [0.329114, 0.075972, 0.489287],
+ [0.335308, 0.078236, 0.491024],
+ [0.341482, 0.080564, 0.492631],
+ [0.347636, 0.082946, 0.494121],
+ [0.353773, 0.085373, 0.495501],
+ [0.359898, 0.087831, 0.496778],
+ [0.366012, 0.090314, 0.497960],
+ [0.372116, 0.092816, 0.499053],
+ [0.378211, 0.095332, 0.500067],
+ [0.384299, 0.097855, 0.501002],
+ [0.390384, 0.100379, 0.501864],
+ [0.396467, 0.102902, 0.502658],
+ [0.402548, 0.105420, 0.503386],
+ [0.408629, 0.107930, 0.504052],
+ [0.414709, 0.110431, 0.504662],
+ [0.420791, 0.112920, 0.505215],
+ [0.426877, 0.115395, 0.505714],
+ [0.432967, 0.117855, 0.506160],
+ [0.439062, 0.120298, 0.506555],
+ [0.445163, 0.122724, 0.506901],
+ [0.451271, 0.125132, 0.507198],
+ [0.457386, 0.127522, 0.507448],
+ [0.463508, 0.129893, 0.507652],
+ [0.469640, 0.132245, 0.507809],
+ [0.475780, 0.134577, 0.507921],
+ [0.481929, 0.136891, 0.507989],
+ [0.488088, 0.139186, 0.508011],
+ [0.494258, 0.141462, 0.507988],
+ [0.500438, 0.143719, 0.507920],
+ [0.506629, 0.145958, 0.507806],
+ [0.512831, 0.148179, 0.507648],
+ [0.519045, 0.150383, 0.507443],
+ [0.525270, 0.152569, 0.507192],
+ [0.531507, 0.154739, 0.506895],
+ [0.537755, 0.156894, 0.506551],
+ [0.544015, 0.159033, 0.506159],
+ [0.550287, 0.161158, 0.505719],
+ [0.556571, 0.163269, 0.505230],
+ [0.562866, 0.165368, 0.504692],
+ [0.569172, 0.167454, 0.504105],
+ [0.575490, 0.169530, 0.503466],
+ [0.581819, 0.171596, 0.502777],
+ [0.588158, 0.173652, 0.502035],
+ [0.594508, 0.175701, 0.501241],
+ [0.600868, 0.177743, 0.500394],
+ [0.607238, 0.179779, 0.499492],
+ [0.613617, 0.181811, 0.498536],
+ [0.620005, 0.183840, 0.497524],
+ [0.626401, 0.185867, 0.496456],
+ [0.632805, 0.187893, 0.495332],
+ [0.639216, 0.189921, 0.494150],
+ [0.645633, 0.191952, 0.492910],
+ [0.652056, 0.193986, 0.491611],
+ [0.658483, 0.196027, 0.490253],
+ [0.664915, 0.198075, 0.488836],
+ [0.671349, 0.200133, 0.487358],
+ [0.677786, 0.202203, 0.485819],
+ [0.684224, 0.204286, 0.484219],
+ [0.690661, 0.206384, 0.482558],
+ [0.697098, 0.208501, 0.480835],
+ [0.703532, 0.210638, 0.479049],
+ [0.709962, 0.212797, 0.477201],
+ [0.716387, 0.214982, 0.475290],
+ [0.722805, 0.217194, 0.473316],
+ [0.729216, 0.219437, 0.471279],
+ [0.735616, 0.221713, 0.469180],
+ [0.742004, 0.224025, 0.467018],
+ [0.748378, 0.226377, 0.464794],
+ [0.754737, 0.228772, 0.462509],
+ [0.761077, 0.231214, 0.460162],
+ [0.767398, 0.233705, 0.457755],
+ [0.773695, 0.236249, 0.455289],
+ [0.779968, 0.238851, 0.452765],
+ [0.786212, 0.241514, 0.450184],
+ [0.792427, 0.244242, 0.447543],
+ [0.798608, 0.247040, 0.444848],
+ [0.804752, 0.249911, 0.442102],
+ [0.810855, 0.252861, 0.439305],
+ [0.816914, 0.255895, 0.436461],
+ [0.822926, 0.259016, 0.433573],
+ [0.828886, 0.262229, 0.430644],
+ [0.834791, 0.265540, 0.427671],
+ [0.840636, 0.268953, 0.424666],
+ [0.846416, 0.272473, 0.421631],
+ [0.852126, 0.276106, 0.418573],
+ [0.857763, 0.279857, 0.415496],
+ [0.863320, 0.283729, 0.412403],
+ [0.868793, 0.287728, 0.409303],
+ [0.874176, 0.291859, 0.406205],
+ [0.879464, 0.296125, 0.403118],
+ [0.884651, 0.300530, 0.400047],
+ [0.889731, 0.305079, 0.397002],
+ [0.894700, 0.309773, 0.393995],
+ [0.899552, 0.314616, 0.391037],
+ [0.904281, 0.319610, 0.388137],
+ [0.908884, 0.324755, 0.385308],
+ [0.913354, 0.330052, 0.382563],
+ [0.917689, 0.335500, 0.379915],
+ [0.921884, 0.341098, 0.377376],
+ [0.925937, 0.346844, 0.374959],
+ [0.929845, 0.352734, 0.372677],
+ [0.933606, 0.358764, 0.370541],
+ [0.937221, 0.364929, 0.368567],
+ [0.940687, 0.371224, 0.366762],
+ [0.944006, 0.377643, 0.365136],
+ [0.947180, 0.384178, 0.363701],
+ [0.950210, 0.390820, 0.362468],
+ [0.953099, 0.397563, 0.361438],
+ [0.955849, 0.404400, 0.360619],
+ [0.958464, 0.411324, 0.360014],
+ [0.960949, 0.418323, 0.359630],
+ [0.963310, 0.425390, 0.359469],
+ [0.965549, 0.432519, 0.359529],
+ [0.967671, 0.439703, 0.359810],
+ [0.969680, 0.446936, 0.360311],
+ [0.971582, 0.454210, 0.361030],
+ [0.973381, 0.461520, 0.361965],
+ [0.975082, 0.468861, 0.363111],
+ [0.976690, 0.476226, 0.364466],
+ [0.978210, 0.483612, 0.366025],
+ [0.979645, 0.491014, 0.367783],
+ [0.981000, 0.498428, 0.369734],
+ [0.982279, 0.505851, 0.371874],
+ [0.983485, 0.513280, 0.374198],
+ [0.984622, 0.520713, 0.376698],
+ [0.985693, 0.528148, 0.379371],
+ [0.986700, 0.535582, 0.382210],
+ [0.987646, 0.543015, 0.385210],
+ [0.988533, 0.550446, 0.388365],
+ [0.989363, 0.557873, 0.391671],
+ [0.990138, 0.565296, 0.395122],
+ [0.990871, 0.572706, 0.398714],
+ [0.991558, 0.580107, 0.402441],
+ [0.992196, 0.587502, 0.406299],
+ [0.992785, 0.594891, 0.410283],
+ [0.993326, 0.602275, 0.414390],
+ [0.993834, 0.609644, 0.418613],
+ [0.994309, 0.616999, 0.422950],
+ [0.994738, 0.624350, 0.427397],
+ [0.995122, 0.631696, 0.431951],
+ [0.995480, 0.639027, 0.436607],
+ [0.995810, 0.646344, 0.441361],
+ [0.996096, 0.653659, 0.446213],
+ [0.996341, 0.660969, 0.451160],
+ [0.996580, 0.668256, 0.456192],
+ [0.996775, 0.675541, 0.461314],
+ [0.996925, 0.682828, 0.466526],
+ [0.997077, 0.690088, 0.471811],
+ [0.997186, 0.697349, 0.477182],
+ [0.997254, 0.704611, 0.482635],
+ [0.997325, 0.711848, 0.488154],
+ [0.997351, 0.719089, 0.493755],
+ [0.997351, 0.726324, 0.499428],
+ [0.997341, 0.733545, 0.505167],
+ [0.997285, 0.740772, 0.510983],
+ [0.997228, 0.747981, 0.516859],
+ [0.997138, 0.755190, 0.522806],
+ [0.997019, 0.762398, 0.528821],
+ [0.996898, 0.769591, 0.534892],
+ [0.996727, 0.776795, 0.541039],
+ [0.996571, 0.783977, 0.547233],
+ [0.996369, 0.791167, 0.553499],
+ [0.996162, 0.798348, 0.559820],
+ [0.995932, 0.805527, 0.566202],
+ [0.995680, 0.812706, 0.572645],
+ [0.995424, 0.819875, 0.579140],
+ [0.995131, 0.827052, 0.585701],
+ [0.994851, 0.834213, 0.592307],
+ [0.994524, 0.841387, 0.598983],
+ [0.994222, 0.848540, 0.605696],
+ [0.993866, 0.855711, 0.612482],
+ [0.993545, 0.862859, 0.619299],
+ [0.993170, 0.870024, 0.626189],
+ [0.992831, 0.877168, 0.633109],
+ [0.992440, 0.884330, 0.640099],
+ [0.992089, 0.891470, 0.647116],
+ [0.991688, 0.898627, 0.654202],
+ [0.991332, 0.905763, 0.661309],
+ [0.990930, 0.912915, 0.668481],
+ [0.990570, 0.920049, 0.675675],
+ [0.990175, 0.927196, 0.682926],
+ [0.989815, 0.934329, 0.690198],
+ [0.989434, 0.941470, 0.697519],
+ [0.989077, 0.948604, 0.704863],
+ [0.988717, 0.955742, 0.712242],
+ [0.988367, 0.962878, 0.719649],
+ [0.988033, 0.970012, 0.727077],
+ [0.987691, 0.977154, 0.734536],
+ [0.987387, 0.984288, 0.742002],
+ [0.987053, 0.991438, 0.749504]]
+
+_inferno_data = [[0.001462, 0.000466, 0.013866],
+ [0.002267, 0.001270, 0.018570],
+ [0.003299, 0.002249, 0.024239],
+ [0.004547, 0.003392, 0.030909],
+ [0.006006, 0.004692, 0.038558],
+ [0.007676, 0.006136, 0.046836],
+ [0.009561, 0.007713, 0.055143],
+ [0.011663, 0.009417, 0.063460],
+ [0.013995, 0.011225, 0.071862],
+ [0.016561, 0.013136, 0.080282],
+ [0.019373, 0.015133, 0.088767],
+ [0.022447, 0.017199, 0.097327],
+ [0.025793, 0.019331, 0.105930],
+ [0.029432, 0.021503, 0.114621],
+ [0.033385, 0.023702, 0.123397],
+ [0.037668, 0.025921, 0.132232],
+ [0.042253, 0.028139, 0.141141],
+ [0.046915, 0.030324, 0.150164],
+ [0.051644, 0.032474, 0.159254],
+ [0.056449, 0.034569, 0.168414],
+ [0.061340, 0.036590, 0.177642],
+ [0.066331, 0.038504, 0.186962],
+ [0.071429, 0.040294, 0.196354],
+ [0.076637, 0.041905, 0.205799],
+ [0.081962, 0.043328, 0.215289],
+ [0.087411, 0.044556, 0.224813],
+ [0.092990, 0.045583, 0.234358],
+ [0.098702, 0.046402, 0.243904],
+ [0.104551, 0.047008, 0.253430],
+ [0.110536, 0.047399, 0.262912],
+ [0.116656, 0.047574, 0.272321],
+ [0.122908, 0.047536, 0.281624],
+ [0.129285, 0.047293, 0.290788],
+ [0.135778, 0.046856, 0.299776],
+ [0.142378, 0.046242, 0.308553],
+ [0.149073, 0.045468, 0.317085],
+ [0.155850, 0.044559, 0.325338],
+ [0.162689, 0.043554, 0.333277],
+ [0.169575, 0.042489, 0.340874],
+ [0.176493, 0.041402, 0.348111],
+ [0.183429, 0.040329, 0.354971],
+ [0.190367, 0.039309, 0.361447],
+ [0.197297, 0.038400, 0.367535],
+ [0.204209, 0.037632, 0.373238],
+ [0.211095, 0.037030, 0.378563],
+ [0.217949, 0.036615, 0.383522],
+ [0.224763, 0.036405, 0.388129],
+ [0.231538, 0.036405, 0.392400],
+ [0.238273, 0.036621, 0.396353],
+ [0.244967, 0.037055, 0.400007],
+ [0.251620, 0.037705, 0.403378],
+ [0.258234, 0.038571, 0.406485],
+ [0.264810, 0.039647, 0.409345],
+ [0.271347, 0.040922, 0.411976],
+ [0.277850, 0.042353, 0.414392],
+ [0.284321, 0.043933, 0.416608],
+ [0.290763, 0.045644, 0.418637],
+ [0.297178, 0.047470, 0.420491],
+ [0.303568, 0.049396, 0.422182],
+ [0.309935, 0.051407, 0.423721],
+ [0.316282, 0.053490, 0.425116],
+ [0.322610, 0.055634, 0.426377],
+ [0.328921, 0.057827, 0.427511],
+ [0.335217, 0.060060, 0.428524],
+ [0.341500, 0.062325, 0.429425],
+ [0.347771, 0.064616, 0.430217],
+ [0.354032, 0.066925, 0.430906],
+ [0.360284, 0.069247, 0.431497],
+ [0.366529, 0.071579, 0.431994],
+ [0.372768, 0.073915, 0.432400],
+ [0.379001, 0.076253, 0.432719],
+ [0.385228, 0.078591, 0.432955],
+ [0.391453, 0.080927, 0.433109],
+ [0.397674, 0.083257, 0.433183],
+ [0.403894, 0.085580, 0.433179],
+ [0.410113, 0.087896, 0.433098],
+ [0.416331, 0.090203, 0.432943],
+ [0.422549, 0.092501, 0.432714],
+ [0.428768, 0.094790, 0.432412],
+ [0.434987, 0.097069, 0.432039],
+ [0.441207, 0.099338, 0.431594],
+ [0.447428, 0.101597, 0.431080],
+ [0.453651, 0.103848, 0.430498],
+ [0.459875, 0.106089, 0.429846],
+ [0.466100, 0.108322, 0.429125],
+ [0.472328, 0.110547, 0.428334],
+ [0.478558, 0.112764, 0.427475],
+ [0.484789, 0.114974, 0.426548],
+ [0.491022, 0.117179, 0.425552],
+ [0.497257, 0.119379, 0.424488],
+ [0.503493, 0.121575, 0.423356],
+ [0.509730, 0.123769, 0.422156],
+ [0.515967, 0.125960, 0.420887],
+ [0.522206, 0.128150, 0.419549],
+ [0.528444, 0.130341, 0.418142],
+ [0.534683, 0.132534, 0.416667],
+ [0.540920, 0.134729, 0.415123],
+ [0.547157, 0.136929, 0.413511],
+ [0.553392, 0.139134, 0.411829],
+ [0.559624, 0.141346, 0.410078],
+ [0.565854, 0.143567, 0.408258],
+ [0.572081, 0.145797, 0.406369],
+ [0.578304, 0.148039, 0.404411],
+ [0.584521, 0.150294, 0.402385],
+ [0.590734, 0.152563, 0.400290],
+ [0.596940, 0.154848, 0.398125],
+ [0.603139, 0.157151, 0.395891],
+ [0.609330, 0.159474, 0.393589],
+ [0.615513, 0.161817, 0.391219],
+ [0.621685, 0.164184, 0.388781],
+ [0.627847, 0.166575, 0.386276],
+ [0.633998, 0.168992, 0.383704],
+ [0.640135, 0.171438, 0.381065],
+ [0.646260, 0.173914, 0.378359],
+ [0.652369, 0.176421, 0.375586],
+ [0.658463, 0.178962, 0.372748],
+ [0.664540, 0.181539, 0.369846],
+ [0.670599, 0.184153, 0.366879],
+ [0.676638, 0.186807, 0.363849],
+ [0.682656, 0.189501, 0.360757],
+ [0.688653, 0.192239, 0.357603],
+ [0.694627, 0.195021, 0.354388],
+ [0.700576, 0.197851, 0.351113],
+ [0.706500, 0.200728, 0.347777],
+ [0.712396, 0.203656, 0.344383],
+ [0.718264, 0.206636, 0.340931],
+ [0.724103, 0.209670, 0.337424],
+ [0.729909, 0.212759, 0.333861],
+ [0.735683, 0.215906, 0.330245],
+ [0.741423, 0.219112, 0.326576],
+ [0.747127, 0.222378, 0.322856],
+ [0.752794, 0.225706, 0.319085],
+ [0.758422, 0.229097, 0.315266],
+ [0.764010, 0.232554, 0.311399],
+ [0.769556, 0.236077, 0.307485],
+ [0.775059, 0.239667, 0.303526],
+ [0.780517, 0.243327, 0.299523],
+ [0.785929, 0.247056, 0.295477],
+ [0.791293, 0.250856, 0.291390],
+ [0.796607, 0.254728, 0.287264],
+ [0.801871, 0.258674, 0.283099],
+ [0.807082, 0.262692, 0.278898],
+ [0.812239, 0.266786, 0.274661],
+ [0.817341, 0.270954, 0.270390],
+ [0.822386, 0.275197, 0.266085],
+ [0.827372, 0.279517, 0.261750],
+ [0.832299, 0.283913, 0.257383],
+ [0.837165, 0.288385, 0.252988],
+ [0.841969, 0.292933, 0.248564],
+ [0.846709, 0.297559, 0.244113],
+ [0.851384, 0.302260, 0.239636],
+ [0.855992, 0.307038, 0.235133],
+ [0.860533, 0.311892, 0.230606],
+ [0.865006, 0.316822, 0.226055],
+ [0.869409, 0.321827, 0.221482],
+ [0.873741, 0.326906, 0.216886],
+ [0.878001, 0.332060, 0.212268],
+ [0.882188, 0.337287, 0.207628],
+ [0.886302, 0.342586, 0.202968],
+ [0.890341, 0.347957, 0.198286],
+ [0.894305, 0.353399, 0.193584],
+ [0.898192, 0.358911, 0.188860],
+ [0.902003, 0.364492, 0.184116],
+ [0.905735, 0.370140, 0.179350],
+ [0.909390, 0.375856, 0.174563],
+ [0.912966, 0.381636, 0.169755],
+ [0.916462, 0.387481, 0.164924],
+ [0.919879, 0.393389, 0.160070],
+ [0.923215, 0.399359, 0.155193],
+ [0.926470, 0.405389, 0.150292],
+ [0.929644, 0.411479, 0.145367],
+ [0.932737, 0.417627, 0.140417],
+ [0.935747, 0.423831, 0.135440],
+ [0.938675, 0.430091, 0.130438],
+ [0.941521, 0.436405, 0.125409],
+ [0.944285, 0.442772, 0.120354],
+ [0.946965, 0.449191, 0.115272],
+ [0.949562, 0.455660, 0.110164],
+ [0.952075, 0.462178, 0.105031],
+ [0.954506, 0.468744, 0.099874],
+ [0.956852, 0.475356, 0.094695],
+ [0.959114, 0.482014, 0.089499],
+ [0.961293, 0.488716, 0.084289],
+ [0.963387, 0.495462, 0.079073],
+ [0.965397, 0.502249, 0.073859],
+ [0.967322, 0.509078, 0.068659],
+ [0.969163, 0.515946, 0.063488],
+ [0.970919, 0.522853, 0.058367],
+ [0.972590, 0.529798, 0.053324],
+ [0.974176, 0.536780, 0.048392],
+ [0.975677, 0.543798, 0.043618],
+ [0.977092, 0.550850, 0.039050],
+ [0.978422, 0.557937, 0.034931],
+ [0.979666, 0.565057, 0.031409],
+ [0.980824, 0.572209, 0.028508],
+ [0.981895, 0.579392, 0.026250],
+ [0.982881, 0.586606, 0.024661],
+ [0.983779, 0.593849, 0.023770],
+ [0.984591, 0.601122, 0.023606],
+ [0.985315, 0.608422, 0.024202],
+ [0.985952, 0.615750, 0.025592],
+ [0.986502, 0.623105, 0.027814],
+ [0.986964, 0.630485, 0.030908],
+ [0.987337, 0.637890, 0.034916],
+ [0.987622, 0.645320, 0.039886],
+ [0.987819, 0.652773, 0.045581],
+ [0.987926, 0.660250, 0.051750],
+ [0.987945, 0.667748, 0.058329],
+ [0.987874, 0.675267, 0.065257],
+ [0.987714, 0.682807, 0.072489],
+ [0.987464, 0.690366, 0.079990],
+ [0.987124, 0.697944, 0.087731],
+ [0.986694, 0.705540, 0.095694],
+ [0.986175, 0.713153, 0.103863],
+ [0.985566, 0.720782, 0.112229],
+ [0.984865, 0.728427, 0.120785],
+ [0.984075, 0.736087, 0.129527],
+ [0.983196, 0.743758, 0.138453],
+ [0.982228, 0.751442, 0.147565],
+ [0.981173, 0.759135, 0.156863],
+ [0.980032, 0.766837, 0.166353],
+ [0.978806, 0.774545, 0.176037],
+ [0.977497, 0.782258, 0.185923],
+ [0.976108, 0.789974, 0.196018],
+ [0.974638, 0.797692, 0.206332],
+ [0.973088, 0.805409, 0.216877],
+ [0.971468, 0.813122, 0.227658],
+ [0.969783, 0.820825, 0.238686],
+ [0.968041, 0.828515, 0.249972],
+ [0.966243, 0.836191, 0.261534],
+ [0.964394, 0.843848, 0.273391],
+ [0.962517, 0.851476, 0.285546],
+ [0.960626, 0.859069, 0.298010],
+ [0.958720, 0.866624, 0.310820],
+ [0.956834, 0.874129, 0.323974],
+ [0.954997, 0.881569, 0.337475],
+ [0.953215, 0.888942, 0.351369],
+ [0.951546, 0.896226, 0.365627],
+ [0.950018, 0.903409, 0.380271],
+ [0.948683, 0.910473, 0.395289],
+ [0.947594, 0.917399, 0.410665],
+ [0.946809, 0.924168, 0.426373],
+ [0.946392, 0.930761, 0.442367],
+ [0.946403, 0.937159, 0.458592],
+ [0.946903, 0.943348, 0.474970],
+ [0.947937, 0.949318, 0.491426],
+ [0.949545, 0.955063, 0.507860],
+ [0.951740, 0.960587, 0.524203],
+ [0.954529, 0.965896, 0.540361],
+ [0.957896, 0.971003, 0.556275],
+ [0.961812, 0.975924, 0.571925],
+ [0.966249, 0.980678, 0.587206],
+ [0.971162, 0.985282, 0.602154],
+ [0.976511, 0.989753, 0.616760],
+ [0.982257, 0.994109, 0.631017],
+ [0.988362, 0.998364, 0.644924]]
+
+_plasma_data = [[0.050383, 0.029803, 0.527975],
+ [0.063536, 0.028426, 0.533124],
+ [0.075353, 0.027206, 0.538007],
+ [0.086222, 0.026125, 0.542658],
+ [0.096379, 0.025165, 0.547103],
+ [0.105980, 0.024309, 0.551368],
+ [0.115124, 0.023556, 0.555468],
+ [0.123903, 0.022878, 0.559423],
+ [0.132381, 0.022258, 0.563250],
+ [0.140603, 0.021687, 0.566959],
+ [0.148607, 0.021154, 0.570562],
+ [0.156421, 0.020651, 0.574065],
+ [0.164070, 0.020171, 0.577478],
+ [0.171574, 0.019706, 0.580806],
+ [0.178950, 0.019252, 0.584054],
+ [0.186213, 0.018803, 0.587228],
+ [0.193374, 0.018354, 0.590330],
+ [0.200445, 0.017902, 0.593364],
+ [0.207435, 0.017442, 0.596333],
+ [0.214350, 0.016973, 0.599239],
+ [0.221197, 0.016497, 0.602083],
+ [0.227983, 0.016007, 0.604867],
+ [0.234715, 0.015502, 0.607592],
+ [0.241396, 0.014979, 0.610259],
+ [0.248032, 0.014439, 0.612868],
+ [0.254627, 0.013882, 0.615419],
+ [0.261183, 0.013308, 0.617911],
+ [0.267703, 0.012716, 0.620346],
+ [0.274191, 0.012109, 0.622722],
+ [0.280648, 0.011488, 0.625038],
+ [0.287076, 0.010855, 0.627295],
+ [0.293478, 0.010213, 0.629490],
+ [0.299855, 0.009561, 0.631624],
+ [0.306210, 0.008902, 0.633694],
+ [0.312543, 0.008239, 0.635700],
+ [0.318856, 0.007576, 0.637640],
+ [0.325150, 0.006915, 0.639512],
+ [0.331426, 0.006261, 0.641316],
+ [0.337683, 0.005618, 0.643049],
+ [0.343925, 0.004991, 0.644710],
+ [0.350150, 0.004382, 0.646298],
+ [0.356359, 0.003798, 0.647810],
+ [0.362553, 0.003243, 0.649245],
+ [0.368733, 0.002724, 0.650601],
+ [0.374897, 0.002245, 0.651876],
+ [0.381047, 0.001814, 0.653068],
+ [0.387183, 0.001434, 0.654177],
+ [0.393304, 0.001114, 0.655199],
+ [0.399411, 0.000859, 0.656133],
+ [0.405503, 0.000678, 0.656977],
+ [0.411580, 0.000577, 0.657730],
+ [0.417642, 0.000564, 0.658390],
+ [0.423689, 0.000646, 0.658956],
+ [0.429719, 0.000831, 0.659425],
+ [0.435734, 0.001127, 0.659797],
+ [0.441732, 0.001540, 0.660069],
+ [0.447714, 0.002080, 0.660240],
+ [0.453677, 0.002755, 0.660310],
+ [0.459623, 0.003574, 0.660277],
+ [0.465550, 0.004545, 0.660139],
+ [0.471457, 0.005678, 0.659897],
+ [0.477344, 0.006980, 0.659549],
+ [0.483210, 0.008460, 0.659095],
+ [0.489055, 0.010127, 0.658534],
+ [0.494877, 0.011990, 0.657865],
+ [0.500678, 0.014055, 0.657088],
+ [0.506454, 0.016333, 0.656202],
+ [0.512206, 0.018833, 0.655209],
+ [0.517933, 0.021563, 0.654109],
+ [0.523633, 0.024532, 0.652901],
+ [0.529306, 0.027747, 0.651586],
+ [0.534952, 0.031217, 0.650165],
+ [0.540570, 0.034950, 0.648640],
+ [0.546157, 0.038954, 0.647010],
+ [0.551715, 0.043136, 0.645277],
+ [0.557243, 0.047331, 0.643443],
+ [0.562738, 0.051545, 0.641509],
+ [0.568201, 0.055778, 0.639477],
+ [0.573632, 0.060028, 0.637349],
+ [0.579029, 0.064296, 0.635126],
+ [0.584391, 0.068579, 0.632812],
+ [0.589719, 0.072878, 0.630408],
+ [0.595011, 0.077190, 0.627917],
+ [0.600266, 0.081516, 0.625342],
+ [0.605485, 0.085854, 0.622686],
+ [0.610667, 0.090204, 0.619951],
+ [0.615812, 0.094564, 0.617140],
+ [0.620919, 0.098934, 0.614257],
+ [0.625987, 0.103312, 0.611305],
+ [0.631017, 0.107699, 0.608287],
+ [0.636008, 0.112092, 0.605205],
+ [0.640959, 0.116492, 0.602065],
+ [0.645872, 0.120898, 0.598867],
+ [0.650746, 0.125309, 0.595617],
+ [0.655580, 0.129725, 0.592317],
+ [0.660374, 0.134144, 0.588971],
+ [0.665129, 0.138566, 0.585582],
+ [0.669845, 0.142992, 0.582154],
+ [0.674522, 0.147419, 0.578688],
+ [0.679160, 0.151848, 0.575189],
+ [0.683758, 0.156278, 0.571660],
+ [0.688318, 0.160709, 0.568103],
+ [0.692840, 0.165141, 0.564522],
+ [0.697324, 0.169573, 0.560919],
+ [0.701769, 0.174005, 0.557296],
+ [0.706178, 0.178437, 0.553657],
+ [0.710549, 0.182868, 0.550004],
+ [0.714883, 0.187299, 0.546338],
+ [0.719181, 0.191729, 0.542663],
+ [0.723444, 0.196158, 0.538981],
+ [0.727670, 0.200586, 0.535293],
+ [0.731862, 0.205013, 0.531601],
+ [0.736019, 0.209439, 0.527908],
+ [0.740143, 0.213864, 0.524216],
+ [0.744232, 0.218288, 0.520524],
+ [0.748289, 0.222711, 0.516834],
+ [0.752312, 0.227133, 0.513149],
+ [0.756304, 0.231555, 0.509468],
+ [0.760264, 0.235976, 0.505794],
+ [0.764193, 0.240396, 0.502126],
+ [0.768090, 0.244817, 0.498465],
+ [0.771958, 0.249237, 0.494813],
+ [0.775796, 0.253658, 0.491171],
+ [0.779604, 0.258078, 0.487539],
+ [0.783383, 0.262500, 0.483918],
+ [0.787133, 0.266922, 0.480307],
+ [0.790855, 0.271345, 0.476706],
+ [0.794549, 0.275770, 0.473117],
+ [0.798216, 0.280197, 0.469538],
+ [0.801855, 0.284626, 0.465971],
+ [0.805467, 0.289057, 0.462415],
+ [0.809052, 0.293491, 0.458870],
+ [0.812612, 0.297928, 0.455338],
+ [0.816144, 0.302368, 0.451816],
+ [0.819651, 0.306812, 0.448306],
+ [0.823132, 0.311261, 0.444806],
+ [0.826588, 0.315714, 0.441316],
+ [0.830018, 0.320172, 0.437836],
+ [0.833422, 0.324635, 0.434366],
+ [0.836801, 0.329105, 0.430905],
+ [0.840155, 0.333580, 0.427455],
+ [0.843484, 0.338062, 0.424013],
+ [0.846788, 0.342551, 0.420579],
+ [0.850066, 0.347048, 0.417153],
+ [0.853319, 0.351553, 0.413734],
+ [0.856547, 0.356066, 0.410322],
+ [0.859750, 0.360588, 0.406917],
+ [0.862927, 0.365119, 0.403519],
+ [0.866078, 0.369660, 0.400126],
+ [0.869203, 0.374212, 0.396738],
+ [0.872303, 0.378774, 0.393355],
+ [0.875376, 0.383347, 0.389976],
+ [0.878423, 0.387932, 0.386600],
+ [0.881443, 0.392529, 0.383229],
+ [0.884436, 0.397139, 0.379860],
+ [0.887402, 0.401762, 0.376494],
+ [0.890340, 0.406398, 0.373130],
+ [0.893250, 0.411048, 0.369768],
+ [0.896131, 0.415712, 0.366407],
+ [0.898984, 0.420392, 0.363047],
+ [0.901807, 0.425087, 0.359688],
+ [0.904601, 0.429797, 0.356329],
+ [0.907365, 0.434524, 0.352970],
+ [0.910098, 0.439268, 0.349610],
+ [0.912800, 0.444029, 0.346251],
+ [0.915471, 0.448807, 0.342890],
+ [0.918109, 0.453603, 0.339529],
+ [0.920714, 0.458417, 0.336166],
+ [0.923287, 0.463251, 0.332801],
+ [0.925825, 0.468103, 0.329435],
+ [0.928329, 0.472975, 0.326067],
+ [0.930798, 0.477867, 0.322697],
+ [0.933232, 0.482780, 0.319325],
+ [0.935630, 0.487712, 0.315952],
+ [0.937990, 0.492667, 0.312575],
+ [0.940313, 0.497642, 0.309197],
+ [0.942598, 0.502639, 0.305816],
+ [0.944844, 0.507658, 0.302433],
+ [0.947051, 0.512699, 0.299049],
+ [0.949217, 0.517763, 0.295662],
+ [0.951344, 0.522850, 0.292275],
+ [0.953428, 0.527960, 0.288883],
+ [0.955470, 0.533093, 0.285490],
+ [0.957469, 0.538250, 0.282096],
+ [0.959424, 0.543431, 0.278701],
+ [0.961336, 0.548636, 0.275305],
+ [0.963203, 0.553865, 0.271909],
+ [0.965024, 0.559118, 0.268513],
+ [0.966798, 0.564396, 0.265118],
+ [0.968526, 0.569700, 0.261721],
+ [0.970205, 0.575028, 0.258325],
+ [0.971835, 0.580382, 0.254931],
+ [0.973416, 0.585761, 0.251540],
+ [0.974947, 0.591165, 0.248151],
+ [0.976428, 0.596595, 0.244767],
+ [0.977856, 0.602051, 0.241387],
+ [0.979233, 0.607532, 0.238013],
+ [0.980556, 0.613039, 0.234646],
+ [0.981826, 0.618572, 0.231287],
+ [0.983041, 0.624131, 0.227937],
+ [0.984199, 0.629718, 0.224595],
+ [0.985301, 0.635330, 0.221265],
+ [0.986345, 0.640969, 0.217948],
+ [0.987332, 0.646633, 0.214648],
+ [0.988260, 0.652325, 0.211364],
+ [0.989128, 0.658043, 0.208100],
+ [0.989935, 0.663787, 0.204859],
+ [0.990681, 0.669558, 0.201642],
+ [0.991365, 0.675355, 0.198453],
+ [0.991985, 0.681179, 0.195295],
+ [0.992541, 0.687030, 0.192170],
+ [0.993032, 0.692907, 0.189084],
+ [0.993456, 0.698810, 0.186041],
+ [0.993814, 0.704741, 0.183043],
+ [0.994103, 0.710698, 0.180097],
+ [0.994324, 0.716681, 0.177208],
+ [0.994474, 0.722691, 0.174381],
+ [0.994553, 0.728728, 0.171622],
+ [0.994561, 0.734791, 0.168938],
+ [0.994495, 0.740880, 0.166335],
+ [0.994355, 0.746995, 0.163821],
+ [0.994141, 0.753137, 0.161404],
+ [0.993851, 0.759304, 0.159092],
+ [0.993482, 0.765499, 0.156891],
+ [0.993033, 0.771720, 0.154808],
+ [0.992505, 0.777967, 0.152855],
+ [0.991897, 0.784239, 0.151042],
+ [0.991209, 0.790537, 0.149377],
+ [0.990439, 0.796859, 0.147870],
+ [0.989587, 0.803205, 0.146529],
+ [0.988648, 0.809579, 0.145357],
+ [0.987621, 0.815978, 0.144363],
+ [0.986509, 0.822401, 0.143557],
+ [0.985314, 0.828846, 0.142945],
+ [0.984031, 0.835315, 0.142528],
+ [0.982653, 0.841812, 0.142303],
+ [0.981190, 0.848329, 0.142279],
+ [0.979644, 0.854866, 0.142453],
+ [0.977995, 0.861432, 0.142808],
+ [0.976265, 0.868016, 0.143351],
+ [0.974443, 0.874622, 0.144061],
+ [0.972530, 0.881250, 0.144923],
+ [0.970533, 0.887896, 0.145919],
+ [0.968443, 0.894564, 0.147014],
+ [0.966271, 0.901249, 0.148180],
+ [0.964021, 0.907950, 0.149370],
+ [0.961681, 0.914672, 0.150520],
+ [0.959276, 0.921407, 0.151566],
+ [0.956808, 0.928152, 0.152409],
+ [0.954287, 0.934908, 0.152921],
+ [0.951726, 0.941671, 0.152925],
+ [0.949151, 0.948435, 0.152178],
+ [0.946602, 0.955190, 0.150328],
+ [0.944152, 0.961916, 0.146861],
+ [0.941896, 0.968590, 0.140956],
+ [0.940015, 0.975158, 0.131326]]
+
+_viridis_data = [[0.267004, 0.004874, 0.329415],
+ [0.268510, 0.009605, 0.335427],
+ [0.269944, 0.014625, 0.341379],
+ [0.271305, 0.019942, 0.347269],
+ [0.272594, 0.025563, 0.353093],
+ [0.273809, 0.031497, 0.358853],
+ [0.274952, 0.037752, 0.364543],
+ [0.276022, 0.044167, 0.370164],
+ [0.277018, 0.050344, 0.375715],
+ [0.277941, 0.056324, 0.381191],
+ [0.278791, 0.062145, 0.386592],
+ [0.279566, 0.067836, 0.391917],
+ [0.280267, 0.073417, 0.397163],
+ [0.280894, 0.078907, 0.402329],
+ [0.281446, 0.084320, 0.407414],
+ [0.281924, 0.089666, 0.412415],
+ [0.282327, 0.094955, 0.417331],
+ [0.282656, 0.100196, 0.422160],
+ [0.282910, 0.105393, 0.426902],
+ [0.283091, 0.110553, 0.431554],
+ [0.283197, 0.115680, 0.436115],
+ [0.283229, 0.120777, 0.440584],
+ [0.283187, 0.125848, 0.444960],
+ [0.283072, 0.130895, 0.449241],
+ [0.282884, 0.135920, 0.453427],
+ [0.282623, 0.140926, 0.457517],
+ [0.282290, 0.145912, 0.461510],
+ [0.281887, 0.150881, 0.465405],
+ [0.281412, 0.155834, 0.469201],
+ [0.280868, 0.160771, 0.472899],
+ [0.280255, 0.165693, 0.476498],
+ [0.279574, 0.170599, 0.479997],
+ [0.278826, 0.175490, 0.483397],
+ [0.278012, 0.180367, 0.486697],
+ [0.277134, 0.185228, 0.489898],
+ [0.276194, 0.190074, 0.493001],
+ [0.275191, 0.194905, 0.496005],
+ [0.274128, 0.199721, 0.498911],
+ [0.273006, 0.204520, 0.501721],
+ [0.271828, 0.209303, 0.504434],
+ [0.270595, 0.214069, 0.507052],
+ [0.269308, 0.218818, 0.509577],
+ [0.267968, 0.223549, 0.512008],
+ [0.266580, 0.228262, 0.514349],
+ [0.265145, 0.232956, 0.516599],
+ [0.263663, 0.237631, 0.518762],
+ [0.262138, 0.242286, 0.520837],
+ [0.260571, 0.246922, 0.522828],
+ [0.258965, 0.251537, 0.524736],
+ [0.257322, 0.256130, 0.526563],
+ [0.255645, 0.260703, 0.528312],
+ [0.253935, 0.265254, 0.529983],
+ [0.252194, 0.269783, 0.531579],
+ [0.250425, 0.274290, 0.533103],
+ [0.248629, 0.278775, 0.534556],
+ [0.246811, 0.283237, 0.535941],
+ [0.244972, 0.287675, 0.537260],
+ [0.243113, 0.292092, 0.538516],
+ [0.241237, 0.296485, 0.539709],
+ [0.239346, 0.300855, 0.540844],
+ [0.237441, 0.305202, 0.541921],
+ [0.235526, 0.309527, 0.542944],
+ [0.233603, 0.313828, 0.543914],
+ [0.231674, 0.318106, 0.544834],
+ [0.229739, 0.322361, 0.545706],
+ [0.227802, 0.326594, 0.546532],
+ [0.225863, 0.330805, 0.547314],
+ [0.223925, 0.334994, 0.548053],
+ [0.221989, 0.339161, 0.548752],
+ [0.220057, 0.343307, 0.549413],
+ [0.218130, 0.347432, 0.550038],
+ [0.216210, 0.351535, 0.550627],
+ [0.214298, 0.355619, 0.551184],
+ [0.212395, 0.359683, 0.551710],
+ [0.210503, 0.363727, 0.552206],
+ [0.208623, 0.367752, 0.552675],
+ [0.206756, 0.371758, 0.553117],
+ [0.204903, 0.375746, 0.553533],
+ [0.203063, 0.379716, 0.553925],
+ [0.201239, 0.383670, 0.554294],
+ [0.199430, 0.387607, 0.554642],
+ [0.197636, 0.391528, 0.554969],
+ [0.195860, 0.395433, 0.555276],
+ [0.194100, 0.399323, 0.555565],
+ [0.192357, 0.403199, 0.555836],
+ [0.190631, 0.407061, 0.556089],
+ [0.188923, 0.410910, 0.556326],
+ [0.187231, 0.414746, 0.556547],
+ [0.185556, 0.418570, 0.556753],
+ [0.183898, 0.422383, 0.556944],
+ [0.182256, 0.426184, 0.557120],
+ [0.180629, 0.429975, 0.557282],
+ [0.179019, 0.433756, 0.557430],
+ [0.177423, 0.437527, 0.557565],
+ [0.175841, 0.441290, 0.557685],
+ [0.174274, 0.445044, 0.557792],
+ [0.172719, 0.448791, 0.557885],
+ [0.171176, 0.452530, 0.557965],
+ [0.169646, 0.456262, 0.558030],
+ [0.168126, 0.459988, 0.558082],
+ [0.166617, 0.463708, 0.558119],
+ [0.165117, 0.467423, 0.558141],
+ [0.163625, 0.471133, 0.558148],
+ [0.162142, 0.474838, 0.558140],
+ [0.160665, 0.478540, 0.558115],
+ [0.159194, 0.482237, 0.558073],
+ [0.157729, 0.485932, 0.558013],
+ [0.156270, 0.489624, 0.557936],
+ [0.154815, 0.493313, 0.557840],
+ [0.153364, 0.497000, 0.557724],
+ [0.151918, 0.500685, 0.557587],
+ [0.150476, 0.504369, 0.557430],
+ [0.149039, 0.508051, 0.557250],
+ [0.147607, 0.511733, 0.557049],
+ [0.146180, 0.515413, 0.556823],
+ [0.144759, 0.519093, 0.556572],
+ [0.143343, 0.522773, 0.556295],
+ [0.141935, 0.526453, 0.555991],
+ [0.140536, 0.530132, 0.555659],
+ [0.139147, 0.533812, 0.555298],
+ [0.137770, 0.537492, 0.554906],
+ [0.136408, 0.541173, 0.554483],
+ [0.135066, 0.544853, 0.554029],
+ [0.133743, 0.548535, 0.553541],
+ [0.132444, 0.552216, 0.553018],
+ [0.131172, 0.555899, 0.552459],
+ [0.129933, 0.559582, 0.551864],
+ [0.128729, 0.563265, 0.551229],
+ [0.127568, 0.566949, 0.550556],
+ [0.126453, 0.570633, 0.549841],
+ [0.125394, 0.574318, 0.549086],
+ [0.124395, 0.578002, 0.548287],
+ [0.123463, 0.581687, 0.547445],
+ [0.122606, 0.585371, 0.546557],
+ [0.121831, 0.589055, 0.545623],
+ [0.121148, 0.592739, 0.544641],
+ [0.120565, 0.596422, 0.543611],
+ [0.120092, 0.600104, 0.542530],
+ [0.119738, 0.603785, 0.541400],
+ [0.119512, 0.607464, 0.540218],
+ [0.119423, 0.611141, 0.538982],
+ [0.119483, 0.614817, 0.537692],
+ [0.119699, 0.618490, 0.536347],
+ [0.120081, 0.622161, 0.534946],
+ [0.120638, 0.625828, 0.533488],
+ [0.121380, 0.629492, 0.531973],
+ [0.122312, 0.633153, 0.530398],
+ [0.123444, 0.636809, 0.528763],
+ [0.124780, 0.640461, 0.527068],
+ [0.126326, 0.644107, 0.525311],
+ [0.128087, 0.647749, 0.523491],
+ [0.130067, 0.651384, 0.521608],
+ [0.132268, 0.655014, 0.519661],
+ [0.134692, 0.658636, 0.517649],
+ [0.137339, 0.662252, 0.515571],
+ [0.140210, 0.665859, 0.513427],
+ [0.143303, 0.669459, 0.511215],
+ [0.146616, 0.673050, 0.508936],
+ [0.150148, 0.676631, 0.506589],
+ [0.153894, 0.680203, 0.504172],
+ [0.157851, 0.683765, 0.501686],
+ [0.162016, 0.687316, 0.499129],
+ [0.166383, 0.690856, 0.496502],
+ [0.170948, 0.694384, 0.493803],
+ [0.175707, 0.697900, 0.491033],
+ [0.180653, 0.701402, 0.488189],
+ [0.185783, 0.704891, 0.485273],
+ [0.191090, 0.708366, 0.482284],
+ [0.196571, 0.711827, 0.479221],
+ [0.202219, 0.715272, 0.476084],
+ [0.208030, 0.718701, 0.472873],
+ [0.214000, 0.722114, 0.469588],
+ [0.220124, 0.725509, 0.466226],
+ [0.226397, 0.728888, 0.462789],
+ [0.232815, 0.732247, 0.459277],
+ [0.239374, 0.735588, 0.455688],
+ [0.246070, 0.738910, 0.452024],
+ [0.252899, 0.742211, 0.448284],
+ [0.259857, 0.745492, 0.444467],
+ [0.266941, 0.748751, 0.440573],
+ [0.274149, 0.751988, 0.436601],
+ [0.281477, 0.755203, 0.432552],
+ [0.288921, 0.758394, 0.428426],
+ [0.296479, 0.761561, 0.424223],
+ [0.304148, 0.764704, 0.419943],
+ [0.311925, 0.767822, 0.415586],
+ [0.319809, 0.770914, 0.411152],
+ [0.327796, 0.773980, 0.406640],
+ [0.335885, 0.777018, 0.402049],
+ [0.344074, 0.780029, 0.397381],
+ [0.352360, 0.783011, 0.392636],
+ [0.360741, 0.785964, 0.387814],
+ [0.369214, 0.788888, 0.382914],
+ [0.377779, 0.791781, 0.377939],
+ [0.386433, 0.794644, 0.372886],
+ [0.395174, 0.797475, 0.367757],
+ [0.404001, 0.800275, 0.362552],
+ [0.412913, 0.803041, 0.357269],
+ [0.421908, 0.805774, 0.351910],
+ [0.430983, 0.808473, 0.346476],
+ [0.440137, 0.811138, 0.340967],
+ [0.449368, 0.813768, 0.335384],
+ [0.458674, 0.816363, 0.329727],
+ [0.468053, 0.818921, 0.323998],
+ [0.477504, 0.821444, 0.318195],
+ [0.487026, 0.823929, 0.312321],
+ [0.496615, 0.826376, 0.306377],
+ [0.506271, 0.828786, 0.300362],
+ [0.515992, 0.831158, 0.294279],
+ [0.525776, 0.833491, 0.288127],
+ [0.535621, 0.835785, 0.281908],
+ [0.545524, 0.838039, 0.275626],
+ [0.555484, 0.840254, 0.269281],
+ [0.565498, 0.842430, 0.262877],
+ [0.575563, 0.844566, 0.256415],
+ [0.585678, 0.846661, 0.249897],
+ [0.595839, 0.848717, 0.243329],
+ [0.606045, 0.850733, 0.236712],
+ [0.616293, 0.852709, 0.230052],
+ [0.626579, 0.854645, 0.223353],
+ [0.636902, 0.856542, 0.216620],
+ [0.647257, 0.858400, 0.209861],
+ [0.657642, 0.860219, 0.203082],
+ [0.668054, 0.861999, 0.196293],
+ [0.678489, 0.863742, 0.189503],
+ [0.688944, 0.865448, 0.182725],
+ [0.699415, 0.867117, 0.175971],
+ [0.709898, 0.868751, 0.169257],
+ [0.720391, 0.870350, 0.162603],
+ [0.730889, 0.871916, 0.156029],
+ [0.741388, 0.873449, 0.149561],
+ [0.751884, 0.874951, 0.143228],
+ [0.762373, 0.876424, 0.137064],
+ [0.772852, 0.877868, 0.131109],
+ [0.783315, 0.879285, 0.125405],
+ [0.793760, 0.880678, 0.120005],
+ [0.804182, 0.882046, 0.114965],
+ [0.814576, 0.883393, 0.110347],
+ [0.824940, 0.884720, 0.106217],
+ [0.835270, 0.886029, 0.102646],
+ [0.845561, 0.887322, 0.099702],
+ [0.855810, 0.888601, 0.097452],
+ [0.866013, 0.889868, 0.095953],
+ [0.876168, 0.891125, 0.095250],
+ [0.886271, 0.892374, 0.095374],
+ [0.896320, 0.893616, 0.096335],
+ [0.906311, 0.894855, 0.098125],
+ [0.916242, 0.896091, 0.100717],
+ [0.926106, 0.897330, 0.104071],
+ [0.935904, 0.898570, 0.108131],
+ [0.945636, 0.899815, 0.112838],
+ [0.955300, 0.901065, 0.118128],
+ [0.964894, 0.902323, 0.123941],
+ [0.974417, 0.903590, 0.130215],
+ [0.983868, 0.904867, 0.136897],
+ [0.993248, 0.906157, 0.143936]]
+
+
+cmaps = {}
+for (name, data) in (('magma', _magma_data),
+ ('inferno', _inferno_data),
+ ('plasma', _plasma_data),
+ ('viridis', _viridis_data)):
+
+ cmaps[name] = ListedColormap(data, name=name)
+
+magma = cmaps['magma']
+inferno = cmaps['inferno']
+plasma = cmaps['plasma']
+viridis = cmaps['viridis']
diff --git a/silx/gui/plot/MaskToolsWidget.py b/silx/gui/plot/MaskToolsWidget.py
new file mode 100644
index 0000000..6407d44
--- /dev/null
+++ b/silx/gui/plot/MaskToolsWidget.py
@@ -0,0 +1,615 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget providing a set of tools to draw masks on a PlotWidget.
+
+This widget is meant to work with :class:`silx.gui.plot.PlotWidget`.
+
+- :class:`ImageMask`: Handle mask bitmap update and history
+- :class:`MaskToolsWidget`: GUI for :class:`Mask`
+- :class:`MaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow`
+"""
+from __future__ import division
+
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "20/04/2017"
+
+
+import os
+import sys
+import numpy
+import logging
+
+from silx.image import shapes
+
+from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget
+from . import items
+from .Colors import cursorColorForColormap, rgba
+from .. import qt
+
+from silx.third_party.EdfFile import EdfFile
+from silx.third_party.TiffIO import TiffIO
+
+try:
+ import fabio
+except ImportError:
+ fabio = None
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ImageMask(BaseMask):
+ """A 2D mask field with update operations.
+
+ Coords follows (row, column) convention and are in mask array coords.
+
+ This is meant for internal use by :class:`MaskToolsWidget`.
+ """
+ def __init__(self, image=None):
+ """
+
+ :param image: :class:`silx.gui.plot.items.ImageBase` instance
+ """
+ BaseMask.__init__(self, image)
+
+ def getDataValues(self):
+ """Return image data as a 2D or 3D array (if it is a RGBA image).
+
+ :rtype: 2D or 3D numpy.ndarray
+ """
+ return self._dataItem.getData(copy=False)
+
+ def save(self, filename, kind):
+ """Save current mask in a file
+
+ :param str filename: The file where to save to mask
+ :param str kind: The kind of file to save in 'edf', 'tif', 'npy',
+ or 'msk' (if FabIO is installed)
+ :raise Exception: Raised if the file writing fail
+ """
+ if kind == 'edf':
+ edfFile = EdfFile(filename, access="w+")
+ edfFile.WriteImage({}, self.getMask(copy=False), Append=0)
+
+ elif kind == 'tif':
+ tiffFile = TiffIO(filename, mode='w')
+ tiffFile.writeImage(self.getMask(copy=False), software='silx')
+
+ elif kind == 'npy':
+ try:
+ numpy.save(filename, self.getMask(copy=False))
+ except IOError:
+ raise RuntimeError("Mask file can't be written")
+
+ elif kind == 'msk':
+ if fabio is None:
+ raise ImportError("Fit2d mask files can't be written: Fabio module is not available")
+ try:
+ data = self.getMask(copy=False)
+ image = fabio.fabioimage.FabioImage(data=data)
+ image = image.convert(fabio.fit2dmaskimage.Fit2dMaskImage)
+ image.save(filename)
+ except Exception:
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError("Mask file can't be written")
+
+ else:
+ raise ValueError("Format '%s' is not supported" % kind)
+
+ # Drawing operations
+ def updateRectangle(self, level, row, col, height, width, mask=True):
+ """Mask/Unmask a rectangle of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int row: Starting row of the rectangle
+ :param int col: Starting column of the rectangle
+ :param int height:
+ :param int width:
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ assert 0 < level < 256
+ selection = self._mask[max(0, row):row + height + 1,
+ max(0, col):col + width + 1]
+ if mask:
+ selection[:, :] = level
+ else:
+ selection[selection == level] = 0
+ self._notify()
+
+ def updatePolygon(self, level, vertices, mask=True):
+ """Mask/Unmask a polygon of the given mask level.
+
+ :param int level: Mask level to update.
+ :param vertices: Nx2 array of polygon corners as (row, col)
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ fill = shapes.polygon_fill_mask(vertices, self._mask.shape)
+ if mask:
+ self._mask[fill != 0] = level
+ else:
+ self._mask[numpy.logical_and(fill != 0,
+ self._mask == level)] = 0
+ self._notify()
+
+ def updatePoints(self, level, rows, cols, mask=True):
+ """Mask/Unmask points with given coordinates.
+
+ :param int level: Mask level to update.
+ :param rows: Rows of selected points
+ :type rows: 1D numpy.ndarray
+ :param cols: Columns of selected points
+ :type cols: 1D numpy.ndarray
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ valid = numpy.logical_and(
+ numpy.logical_and(rows >= 0, cols >= 0),
+ numpy.logical_and(rows < self._mask.shape[0],
+ cols < self._mask.shape[1]))
+ rows, cols = rows[valid], cols[valid]
+
+ if mask:
+ self._mask[rows, cols] = level
+ else:
+ inMask = self._mask[rows, cols] == level
+ self._mask[rows[inMask], cols[inMask]] = 0
+ self._notify()
+
+ def updateDisk(self, level, crow, ccol, radius, mask=True):
+ """Mask/Unmask a disk of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Disk center row.
+ :param int ccol: Disk center column.
+ :param float radius: Radius of the disk in mask array unit
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ rows, cols = shapes.circle_fill(crow, ccol, radius)
+ self.updatePoints(level, rows, cols, mask)
+
+ def updateLine(self, level, row0, col0, row1, col1, width, mask=True):
+ """Mask/Unmask a line of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int row0: Row of the starting point.
+ :param int col0: Column of the starting point.
+ :param int row1: Row of the end point.
+ :param int col1: Column of the end point.
+ :param int width: Width of the line in mask array unit.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ rows, cols = shapes.draw_line(row0, col0, row1, col1, width)
+ self.updatePoints(level, rows, cols, mask)
+
+
+class MaskToolsWidget(BaseMaskToolsWidget):
+ """Widget with tools for drawing mask on an image in a PlotWidget."""
+
+ _maxLevelNumber = 255
+
+ def __init__(self, parent=None, plot=None):
+ self._origin = (0., 0.) # Mask origin in plot
+ self._scale = (1., 1.) # Mask scale in plot
+ self._z = 1 # Mask layer in plot
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8) # Store image
+
+ self._mask = ImageMask()
+
+ super(MaskToolsWidget, self).__init__(parent, plot)
+
+ self._initWidgets()
+
+ def setSelectionMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask: The array to use for the mask.
+ :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ :return: None if failed, shape of mask as 2-tuple if successful.
+ The mask can be cropped or padded to fit active image,
+ the returned shape is that of the active image.
+ """
+ mask = numpy.array(mask, copy=False, dtype=numpy.uint8)
+ if len(mask.shape) != 2:
+ _logger.error('Not an image, shape: %d', len(mask.shape))
+ return None
+
+ if self._data.shape[0:2] == (0, 0) or mask.shape == self._data.shape[0:2]:
+ self._mask.setMask(mask, copy=copy)
+ self._mask.commit()
+ return mask.shape
+ else:
+ _logger.warning('Mask has not the same size as current image.'
+ ' Mask will be cropped or padded to fit image'
+ ' dimensions. %s != %s',
+ str(mask.shape), str(self._data.shape))
+ resizedMask = numpy.zeros(self._data.shape[0:2],
+ dtype=numpy.uint8)
+ height = min(self._data.shape[0], mask.shape[0])
+ width = min(self._data.shape[1], mask.shape[1])
+ resizedMask[:height, :width] = mask[:height, :width]
+ self._mask.setMask(resizedMask, copy=False)
+ self._mask.commit()
+ return resizedMask.shape
+
+ # Handle mask refresh on the plot
+ def _updatePlotMask(self):
+ """Update mask image in plot"""
+ mask = self.getSelectionMask(copy=False)
+ if len(mask):
+ self.plot.addImage(mask, legend=self._maskName,
+ colormap=self._colormap,
+ origin=self._origin,
+ scale=self._scale,
+ z=self._z,
+ replace=False, resetzoom=False)
+ elif self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind='image')
+
+ def showEvent(self, event):
+ try:
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChangedAfterCare)
+ except (RuntimeError, TypeError):
+ pass
+ self._activeImageChanged() # Init mask + enable/disable widget
+ self.plot.sigActiveImageChanged.connect(self._activeImageChanged)
+
+ def hideEvent(self, event):
+ self.plot.sigActiveImageChanged.disconnect(self._activeImageChanged)
+ if not self.browseAction.isChecked():
+ self.browseAction.trigger() # Disable drawing tool
+
+ if len(self.getSelectionMask(copy=False)):
+ self.plot.sigActiveImageChanged.connect(
+ self._activeImageChangedAfterCare)
+
+ def _setOverlayColorForImage(self, image):
+ """Set the color of overlay adapted to image
+
+ :param image: :class:`.items.ImageBase` object to set color for.
+ """
+ if isinstance(image, items.ColormapMixIn):
+ colormap = image.getColormap()
+ self._defaultOverlayColor = rgba(
+ cursorColorForColormap(colormap['name']))
+ else:
+ self._defaultOverlayColor = rgba('black')
+
+ def _activeImageChangedAfterCare(self, *args):
+ """Check synchro of active image and mask when mask widget is hidden.
+
+ If active image has no more the same size as the mask, the mask is
+ removed, otherwise it is adjusted to origin, scale and z.
+ """
+ activeImage = self.plot.getActiveImage()
+ if activeImage is None or activeImage.getLegend() == self._maskName:
+ # No active image or active image is the mask...
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChangedAfterCare)
+ else:
+ self._setOverlayColorForImage(activeImage)
+ self._setMaskColors(self.levelSpinBox.value(),
+ self.transparencySlider.value() /
+ self.transparencySlider.maximum())
+
+ self._origin = activeImage.getOrigin()
+ self._scale = activeImage.getScale()
+ self._z = activeImage.getZValue() + 1
+ self._data = activeImage.getData(copy=False)
+ if self._data.shape[:2] != self.getSelectionMask(copy=False).shape:
+ # Image has not the same size, remove mask and stop listening
+ if self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind='image')
+
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChangedAfterCare)
+ else:
+ # Refresh in case origin, scale, z changed
+ self._mask.setDataItem(activeImage)
+ self._updatePlotMask()
+
+ def _activeImageChanged(self, *args):
+ """Update widget and mask according to active image changes"""
+ activeImage = self.plot.getActiveImage()
+ if activeImage is None or activeImage.getLegend() == self._maskName:
+ # No active image or active image is the mask...
+ self.setEnabled(False)
+
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
+ self._mask.reset()
+ self._mask.commit()
+
+ else: # There is an active image
+ self.setEnabled(True)
+
+ self._setOverlayColorForImage(activeImage)
+
+ self._setMaskColors(self.levelSpinBox.value(),
+ self.transparencySlider.value() /
+ self.transparencySlider.maximum())
+
+ self._origin = activeImage.getOrigin()
+ self._scale = activeImage.getScale()
+ self._z = activeImage.getZValue() + 1
+ self._data = activeImage.getData(copy=False)
+ self._mask.setDataItem(activeImage)
+ if self._data.shape[:2] != self.getSelectionMask(copy=False).shape:
+ self._mask.reset(self._data.shape[:2])
+ self._mask.commit()
+ else:
+ # Refresh in case origin, scale, z changed
+ self._updatePlotMask()
+
+ # Threshold tools only available for data with colormap
+ self.thresholdGroup.setEnabled(self._data.ndim == 2)
+
+ self._updateInteractiveMode()
+
+ # Handle whole mask operations
+ def load(self, filename):
+ """Load a mask from an image file.
+
+ :param str filename: File name from which to load the mask
+ :raise Exception: An exception in case of failure
+ :raise RuntimeWarning: In case the mask was applied but with some
+ import changes to notice
+ """
+ _, extension = os.path.splitext(filename)
+ extension = extension.lower()[1:]
+
+ if extension == "npy":
+ try:
+ mask = numpy.load(filename)
+ except IOError:
+ _logger.error("Can't load filename '%s'", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError('File "%s" is not a numpy file.', filename)
+ elif extension == "edf":
+ try:
+ mask = EdfFile(filename, access='r').GetData(0)
+ except Exception as e:
+ _logger.error("Can't load filename %s", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise e
+ elif extension == "msk":
+ if fabio is None:
+ raise ImportError("Fit2d mask files can't be read: Fabio module is not available")
+ try:
+ mask = fabio.open(filename).data
+ except Exception as e:
+ _logger.error("Can't load fit2d mask file")
+ _logger.debug("Backtrace", exc_info=True)
+ raise e
+ else:
+ msg = "Extension '%s' is not supported."
+ raise RuntimeError(msg % extension)
+
+ effectiveMaskShape = self.setSelectionMask(mask, copy=False)
+ if effectiveMaskShape is None:
+ return
+ if mask.shape != effectiveMaskShape:
+ msg = 'Mask was resized from %s to %s'
+ msg = msg % (str(mask.shape), str(effectiveMaskShape))
+ raise RuntimeWarning(msg)
+
+ def _loadMask(self):
+ """Open load mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Load Mask")
+ dialog.setModal(1)
+ filters = [
+ 'EDF (*.edf)',
+ 'TIFF (*.tif)',
+ 'NumPy binary file (*.npy)',
+ # Fit2D mask is displayed anyway fabio is here or not
+ # to show to the user that the option exists
+ 'Fit2D mask (*.msk)',
+ ]
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.ExistingFile)
+ dialog.setDirectory(self.maskFileDir)
+ if not dialog.exec_():
+ dialog.close()
+ return
+
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ self.maskFileDir = os.path.dirname(filename)
+ try:
+ self.load(filename)
+ except RuntimeWarning as e:
+ message = e.args[0]
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Warning)
+ msg.setText("Mask loaded but an operation was applied.\n" + message)
+ msg.exec_()
+ except Exception as e:
+ message = e.args[0]
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Cannot load mask from file. " + message)
+ msg.exec_()
+
+ def _saveMask(self):
+ """Open Save mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Save Mask")
+ dialog.setModal(1)
+ filters = [
+ 'EDF (*.edf)',
+ 'TIFF (*.tif)',
+ 'NumPy binary file (*.npy)',
+ # Fit2D mask is displayed anyway fabio is here or not
+ # to show to the user that the option exists
+ 'Fit2D mask (*.msk)',
+ ]
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+ dialog.setDirectory(self.maskFileDir)
+ if not dialog.exec_():
+ dialog.close()
+ return
+
+ # convert filter name to extension name with the .
+ extension = dialog.selectedNameFilter().split()[-1][2:-1]
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ if not filename.lower().endswith(extension):
+ filename += extension
+
+ if os.path.exists(filename):
+ try:
+ os.remove(filename)
+ except IOError:
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Cannot save.\n"
+ "Input Output Error: %s" % (sys.exc_info()[1]))
+ msg.exec_()
+ return
+
+ self.maskFileDir = os.path.dirname(filename)
+ try:
+ self.save(filename, extension[1:])
+ except Exception as e:
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Cannot save file %s\n%s" % (filename, e.args[0]))
+ msg.exec_()
+
+ def resetSelectionMask(self):
+ """Reset the mask"""
+ self._mask.reset(shape=self._data.shape[:2])
+ self._mask.commit()
+
+ def _plotDrawEvent(self, event):
+ """Handle draw events from the plot"""
+ if (self._drawingMode is None or
+ event['event'] not in ('drawingProgress', 'drawingFinished')):
+ return
+
+ if not len(self._data):
+ return
+
+ level = self.levelSpinBox.value()
+
+ if (self._drawingMode == 'rectangle' and
+ event['event'] == 'drawingFinished'):
+ # Convert from plot to array coords
+ doMask = self._isMasking()
+ ox, oy = self._origin
+ sx, sy = self._scale
+
+ height = int(abs(event['height'] / sy))
+ width = int(abs(event['width'] / sx))
+
+ row = int((event['y'] - oy) / sy)
+ if sy < 0:
+ row -= height
+
+ col = int((event['x'] - ox) / sx)
+ if sx < 0:
+ col -= width
+
+ self._mask.updateRectangle(
+ level,
+ row=row,
+ col=col,
+ height=height,
+ width=width,
+ mask=doMask)
+ self._mask.commit()
+
+ elif (self._drawingMode == 'polygon' and
+ event['event'] == 'drawingFinished'):
+ doMask = self._isMasking()
+ # Convert from plot to array coords
+ vertices = (event['points'] - self._origin) / self._scale
+ vertices = vertices.astype(numpy.int)[:, (1, 0)] # (row, col)
+ self._mask.updatePolygon(level, vertices, doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'pencil':
+ doMask = self._isMasking()
+ # convert from plot to array coords
+ col, row = (event['points'][-1] - self._origin) / self._scale
+ col, row = int(col), int(row)
+ brushSize = self.pencilSpinBox.value()
+
+ if self._lastPencilPos != (row, col):
+ if self._lastPencilPos is not None:
+ # Draw the line
+ self._mask.updateLine(
+ level,
+ self._lastPencilPos[0], self._lastPencilPos[1],
+ row, col,
+ brushSize,
+ doMask)
+
+ # Draw the very first, or last point
+ self._mask.updateDisk(level, row, col, brushSize / 2., doMask)
+
+ if event['event'] == 'drawingFinished':
+ self._mask.commit()
+ self._lastPencilPos = None
+ else:
+ self._lastPencilPos = row, col
+
+ def _loadRangeFromColormapTriggered(self):
+ """Set range from active image colormap range"""
+ activeImage = self.plot.getActiveImage()
+ if (isinstance(activeImage, items.ColormapMixIn) and
+ activeImage.getLegend() != self._maskName):
+ # Update thresholds according to colormap
+ colormap = activeImage.getColormap()
+ if colormap['autoscale']:
+ min_ = numpy.nanmin(activeImage.getData(copy=False))
+ max_ = numpy.nanmax(activeImage.getData(copy=False))
+ else:
+ min_, max_ = colormap['vmin'], colormap['vmax']
+ self.minLineEdit.setText(str(min_))
+ self.maxLineEdit.setText(str(max_))
+
+
+class MaskToolsDockWidget(BaseMaskToolsDockWidget):
+ """:class:`MaskToolsWidget` embedded in a QDockWidget.
+
+ For integration in a :class:`PlotWindow`.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: The PlotWidget this widget is operating on
+ :paran str name: The title of this widget
+ """
+ def __init__(self, parent=None, plot=None, name='Mask'):
+ super(MaskToolsDockWidget, self).__init__(parent, name)
+ self.setWidget(MaskToolsWidget(plot=plot))
+ self.widget().sigMaskChanged.connect(self._emitSigMaskChanged)
diff --git a/silx/gui/plot/Plot.py b/silx/gui/plot/Plot.py
new file mode 100644
index 0000000..fe0a7b8
--- /dev/null
+++ b/silx/gui/plot/Plot.py
@@ -0,0 +1,2925 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Plot API for 1D and 2D data.
+
+The :class:`Plot` implements the plot API initially provided in PyMca.
+
+
+Colormap
+--------
+
+The :class:`Plot` uses a dictionary to describe a colormap.
+This dictionary has the following keys:
+
+- 'name': str, name of the colormap. Available colormap are returned by
+ :meth:`Plot.getSupportedColormaps`.
+ At least 'gray', 'reversed gray', 'temperature',
+ 'red', 'green', 'blue' are supported.
+- 'normalization': Either 'linear' or 'log'
+- 'autoscale': bool, True to get bounds from the min and max of the
+ data, False to use [vmin, vmax]
+- 'vmin': float, min value, ignored if autoscale is True
+- 'vmax': float, max value, ignored if autoscale is True
+- 'colors': optional, custom colormap.
+ Nx3 or Nx4 numpy array of RGB(A) colors,
+ either uint8 or float in [0, 1].
+ If 'name' is None, then this array is used as the colormap.
+
+
+Plot Events
+-----------
+
+The Plot sends some event to the registered callback
+(See :meth:`Plot.setCallback`).
+Those events are sent as a dictionary with a key 'event' describing the kind
+of event.
+
+Drawing events
+..............
+
+'drawingProgress' and 'drawingFinished' events are sent during drawing
+interaction (See :meth:`Plot.setInteractiveMode`).
+
+- 'event': 'drawingProgress' or 'drawingFinished'
+- 'parameters': dict of parameters used by the drawing mode.
+ It has the following keys: 'shape', 'label', 'color'.
+ See :meth:`Plot.setInteractiveMode`.
+- 'points': Points (x, y) in data coordinates of the drawn shape.
+ For 'hline' and 'vline', it is the 2 points defining the line.
+ For 'line' and 'rectangle', it is the coordinates of the start
+ drawing point and the latest drawing point.
+ For 'polygon', it is the coordinates of all points of the shape.
+- 'type': The type of drawing in 'line', 'hline', 'polygon', 'rectangle',
+ 'vline'.
+- 'xdata' and 'ydata': X coords and Y coords of shape points in data
+ coordinates (as in 'points').
+
+When the type is 'rectangle', the following additional keys are provided:
+
+- 'x' and 'y': The origin of the rectangle in data coordinates
+- 'widht' and 'height': The size of the rectangle in data coordinates
+
+
+Mouse events
+............
+
+'mouseMoved', 'mouseClicked' and 'mouseDoubleClicked' events are sent for
+mouse events.
+
+They provide the following keys:
+
+- 'event': 'mouseMoved', 'mouseClicked' or 'mouseDoubleClicked'
+- 'button': the mouse button that was pressed in 'left', 'middle', 'right'
+- 'x' and 'y': The mouse position in data coordinates
+- 'xpixel' and 'ypixel': The mouse position in pixels
+
+
+Marker events
+.............
+
+'hover', 'markerClicked', 'markerMoving' and 'markerMoved' events are
+sent during interaction with markers.
+
+'hover' is sent when the mouse cursor is over a marker.
+'markerClicker' is sent when the user click on a selectable marker.
+'markerMoving' and 'markerMoved' are sent when a draggable marker is moved.
+
+They provide the following keys:
+
+- 'event': 'hover', 'markerClicked', 'markerMoving' or 'markerMoved'
+- 'button': the mouse button that is pressed in 'left', 'middle', 'right'
+- 'draggable': True if the marker is draggable, False otherwise
+- 'label': The legend associated with the clicked image or curve
+- 'selectable': True if the marker is selectable, False otherwise
+- 'type': 'marker'
+- 'x' and 'y': The mouse position in data coordinates
+- 'xdata' and 'ydata': The marker position in data coordinates
+
+'markerClicked' and 'markerMoving' events have a 'xpixel' and a 'ypixel'
+additional keys, that provide the mouse position in pixels.
+
+
+Image and curve events
+......................
+
+'curveClicked' and 'imageClicked' events are sent when a selectable curve
+or image is clicked.
+
+Both share the following keys:
+
+- 'event': 'curveClicked' or 'imageClicked'
+- 'button': the mouse button that was pressed in 'left', 'middle', 'right'
+- 'label': The legend associated with the clicked image or curve
+- 'type': The type of item in 'curve', 'image'
+- 'x' and 'y': The clicked position in data coordinates
+- 'xpixel' and 'ypixel': The clicked position in pixels
+
+'curveClicked' events have a 'xdata' and a 'ydata' additional keys, that
+provide the coordinates of the picked points of the curve.
+There can be more than one point of the curve being picked, and if a line of
+the curve is picked, only the first point of the line is included in the list.
+
+'imageClicked' have a 'col' and a 'row' additional keys, that provide
+the column and row index in the image array that was clicked.
+
+
+Limits changed events
+.....................
+
+'limitsChanged' events are sent when the limits of the plot are changed.
+This can results from user interaction or API calls.
+
+It provides the following keys:
+
+- 'event': 'limitsChanged'
+- 'source': id of the widget that emitted this event.
+- 'xdata': Range of X in graph coordinates: (xMin, xMax).
+- 'ydata': Range of Y in graph coordinates: (yMin, yMax).
+- 'y2data': Range of right axis in graph coordinates (y2Min, y2Max) or None.
+
+Plot state change events
+........................
+
+The following events are emitted when the plot is modified.
+They provide the new state:
+
+- 'setGraphCursor' event with a 'state' key (bool)
+- 'setGraphGrid' event with a 'which' key (str), see :meth:`setGraphGrid`
+- 'setKeepDataAspectRatio' event with a 'state' key (bool)
+- 'setXAxisAutoScale' event with a 'state' key (bool)
+- 'setXAxisLogarithmic' event with a 'state' key (bool)
+- 'setYAxisAutoScale' event with a 'state' key (bool)
+- 'setYAxisInverted' event with a 'state' key (bool)
+- 'setYAxisLogarithmic' event with a 'state' key (bool)
+
+A 'contentChanged' event is triggered when the content of the plot is updated.
+It provides the following keys:
+
+- 'action': The change of the plot: 'add' or 'remove'
+- 'kind': The kind of primitive changed: 'curve', 'image', 'item' or 'marker'
+- 'legend': The legend of the primitive changed.
+
+'activeCurveChanged' and 'activeImageChanged' events with the following keys:
+
+- 'legend': Name (str) of the current active item or None if no active item.
+- 'previous': Name (str) of the previous active item or None if no item was
+ active. It is the same as 'legend' if 'updated' == True
+- 'updated': (bool) True if active item name did not changed,
+ but active item data or style was updated.
+
+'interactiveModeChanged' event with a 'source' key identifying the object
+setting the interactive mode.
+"""
+
+from __future__ import division
+
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/02/2017"
+
+
+from collections import OrderedDict, namedtuple
+import itertools
+import logging
+
+import numpy
+
+# Import matplotlib backend here to init matplotlib our way
+from .backends.BackendMatplotlib import BackendMatplotlibQt
+
+try:
+ from matplotlib import cm as matplotlib_cm
+except ImportError:
+ matplotlib_cm = None
+
+from . import Colors
+from . import PlotInteraction
+from . import PlotEvents
+from . import _utils
+
+from . import items
+
+
+_logger = logging.getLogger(__name__)
+
+
+_COLORDICT = Colors.COLORDICT
+_COLORLIST = [_COLORDICT['black'],
+ _COLORDICT['blue'],
+ _COLORDICT['red'],
+ _COLORDICT['green'],
+ _COLORDICT['pink'],
+ _COLORDICT['yellow'],
+ _COLORDICT['brown'],
+ _COLORDICT['cyan'],
+ _COLORDICT['magenta'],
+ _COLORDICT['orange'],
+ _COLORDICT['violet'],
+ # _COLORDICT['bluegreen'],
+ _COLORDICT['grey'],
+ _COLORDICT['darkBlue'],
+ _COLORDICT['darkRed'],
+ _COLORDICT['darkGreen'],
+ _COLORDICT['darkCyan'],
+ _COLORDICT['darkMagenta'],
+ _COLORDICT['darkYellow'],
+ _COLORDICT['darkBrown']]
+
+
+"""
+Object returned when requesting the data range.
+"""
+_PlotDataRange = namedtuple('PlotDataRange',
+ ['x', 'y', 'yright'])
+
+
+class Plot(object):
+ """This class implements the plot API initially provided in PyMca.
+
+ Supported backends:
+
+ - 'matplotlib' and 'mpl': Matplotlib with Qt.
+ - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1)
+ - 'none': No backend, to run headless for testing purpose.
+
+ :param parent: The parent widget of the plot (Default: None)
+ :param backend: The backend to use. A str in:
+ 'matplotlib', 'mpl', 'opengl', 'gl', 'none'
+ or a :class:`BackendBase.BackendBase` class
+ """
+
+ DEFAULT_BACKEND = 'matplotlib'
+ """Class attribute setting the default backend for all instances."""
+
+ colorList = _COLORLIST
+ colorDict = _COLORDICT
+
+ def __init__(self, parent=None, backend=None):
+ self._autoreplot = False
+ self._dirty = False
+ self._cursorInPlot = False
+
+ if backend is None:
+ backend = self.DEFAULT_BACKEND
+
+ if hasattr(backend, "__call__"):
+ self._backend = backend(self, parent)
+
+ elif hasattr(backend, "lower"):
+ lowerCaseString = backend.lower()
+ if lowerCaseString in ("matplotlib", "mpl"):
+ backendClass = BackendMatplotlibQt
+ elif lowerCaseString in ('gl', 'opengl'):
+ from .backends.BackendOpenGL import BackendOpenGL
+ backendClass = BackendOpenGL
+ elif lowerCaseString == 'none':
+ from .backends.BackendBase import BackendBase as backendClass
+ else:
+ raise ValueError("Backend not supported %s" % backend)
+ self._backend = backendClass(self, parent)
+
+ else:
+ raise ValueError("Backend not supported %s" % str(backend))
+
+ super(Plot, self).__init__()
+
+ self.setCallback() # set _callback
+
+ # Items handling
+ self._content = OrderedDict()
+ self._contentToUpdate = set()
+
+ self._dataRange = None
+
+ # line types
+ self._styleList = ['-', '--', '-.', ':']
+ self._colorIndex = 0
+ self._styleIndex = 0
+
+ self._activeCurveHandling = True
+ self._activeCurveColor = "#000000"
+ self._activeLegend = {'curve': None, 'image': None,
+ 'scatter': None}
+
+ # default properties
+ self._cursorConfiguration = None
+
+ self._logY = False
+ self._logX = False
+ self._xAutoScale = True
+ self._yAutoScale = True
+ self._grid = None
+
+ # Store default labels provided to setGraph[X|Y]Label
+ self._defaultLabels = {'x': '', 'y': '', 'yright': ''}
+ # Store currently displayed labels
+ # Current label can differ from input one with active curve handling
+ self._currentLabels = {'x': '', 'y': '', 'yright': ''}
+
+ self._graphTitle = ''
+
+ self.setGraphTitle()
+ self.setGraphXLabel()
+ self.setGraphYLabel()
+ self.setGraphYLabel('', axis='right')
+
+ self.setDefaultColormap() # Init default colormap
+
+ self.setDefaultPlotPoints(False)
+ self.setDefaultPlotLines(True)
+
+ self._eventHandler = PlotInteraction.PlotInteraction(self)
+ self._eventHandler.setInteractiveMode('zoom', color=(0., 0., 0., 1.))
+
+ self._pressedButtons = [] # Currently pressed mouse buttons
+
+ self._defaultDataMargins = (0., 0., 0., 0.)
+
+ # Only activate autoreplot at the end
+ # This avoids errors when loaded in Qt designer
+ self._dirty = False
+ self._autoreplot = True
+
+ def _getDirtyPlot(self):
+ """Return the plot dirty flag.
+
+ If False, the plot has not changed since last replot.
+ If True, the full plot need to be redrawn.
+ If 'overlay', only the overlay has changed since last replot.
+
+ It can be accessed by backend to check the dirty state.
+
+ :return: False, True, 'overlay'
+ """
+ return self._dirty
+
+ def _setDirtyPlot(self, overlayOnly=False):
+ """Mark the plot as needing redraw
+
+ :param bool overlayOnly: True to redraw only the overlay,
+ False to redraw everything
+ """
+ wasDirty = self._dirty
+
+ if not self._dirty and overlayOnly:
+ self._dirty = 'overlay'
+ else:
+ self._dirty = True
+
+ if self._autoreplot and not wasDirty:
+ self._backend.postRedisplay()
+
+ def _invalidateDataRange(self):
+ """
+ Notifies this Plot instance that the range has changed and will have
+ to be recomputed.
+ """
+ self._dataRange = None
+
+ def _updateDataRange(self):
+ """
+ Recomputes the range of the data displayed on this Plot.
+ """
+ xMin = yMinLeft = yMinRight = float('nan')
+ xMax = yMaxLeft = yMaxRight = float('nan')
+
+ for item in self._content.values():
+ if item.isVisible():
+ bounds = item.getBounds()
+ if bounds is not None:
+ xMin = numpy.nanmin([xMin, bounds[0]])
+ xMax = numpy.nanmax([xMax, bounds[1]])
+ # Take care of right axis
+ if (isinstance(item, items.YAxisMixIn) and
+ item.getYAxis() == 'right'):
+ yMinRight = numpy.nanmin([yMinRight, bounds[2]])
+ yMaxRight = numpy.nanmax([yMaxRight, bounds[3]])
+ else:
+ yMinLeft = numpy.nanmin([yMinLeft, bounds[2]])
+ yMaxLeft = numpy.nanmax([yMaxLeft, bounds[3]])
+
+ def lGetRange(x, y):
+ return None if numpy.isnan(x) and numpy.isnan(y) else (x, y)
+ xRange = lGetRange(xMin, xMax)
+ yLeftRange = lGetRange(yMinLeft, yMaxLeft)
+ yRightRange = lGetRange(yMinRight, yMaxRight)
+
+ self._dataRange = _PlotDataRange(x=xRange,
+ y=yLeftRange,
+ yright=yRightRange)
+
+ def getDataRange(self):
+ """
+ Returns this Plot's data range.
+
+ :return: a namedtuple with the following members :
+ x, y (left y axis), yright. Each member is a tuple (min, max)
+ or None if no data is associated with the axis.
+ :rtype: namedtuple
+ """
+ if self._dataRange is None:
+ self._updateDataRange()
+ return self._dataRange
+
+ # Content management
+
+ @staticmethod
+ def _itemKey(item):
+ """Build the key of given :class:`Item` in the plot
+
+ :param Item item: The item to make the key from
+ :return: (legend, kind)
+ :rtype: (str, str)
+ """
+ if isinstance(item, items.Curve):
+ kind = 'curve'
+ elif isinstance(item, items.ImageBase):
+ kind = 'image'
+ elif isinstance(item, items.Scatter):
+ kind = 'scatter'
+ elif isinstance(item, (items.Marker,
+ items.XMarker, items.YMarker)):
+ kind = 'marker'
+ elif isinstance(item, items.Shape):
+ kind = 'item'
+ elif isinstance(item, items.Histogram):
+ kind = 'histogram'
+ else:
+ raise ValueError('Unsupported item type %s' % type(item))
+
+ return item.getLegend(), kind
+
+ def _add(self, item):
+ """Add the given :class:`Item` to the plot.
+
+ :param Item item: The item to append to the plot content
+ """
+ key = self._itemKey(item)
+ if key in self._content:
+ raise RuntimeError('Item already in the plot')
+
+ # Add item to plot
+ self._content[key] = item
+ item._setPlot(self)
+ if item.isVisible():
+ self._itemRequiresUpdate(item)
+ if isinstance(item, (items.Curve, items.ImageBase)):
+ self._invalidateDataRange() # TODO handle this automatically
+
+ def _remove(self, item):
+ """Remove the given :class:`Item` from the plot.
+
+ :param Item item: The item to remove from the plot content
+ """
+ key = self._itemKey(item)
+ if key not in self._content:
+ raise RuntimeError('Item not in the plot')
+
+ # Remove item from plot
+ self._content.pop(key)
+ self._contentToUpdate.discard(item)
+ if item.isVisible():
+ self._setDirtyPlot(overlayOnly=item.isOverlay())
+ if item.getBounds() is not None:
+ self._invalidateDataRange()
+ item._removeBackendRenderer(self._backend)
+ item._setPlot(None)
+
+ def _itemRequiresUpdate(self, item):
+ """Called by items in the plot for asynchronous update
+
+ :param Item item: The item that required update
+ """
+ assert item.getPlot() == self
+ self._contentToUpdate.add(item)
+ self._setDirtyPlot(overlayOnly=item.isOverlay())
+
+ # Add
+
+ # add * input arguments management:
+ # If an arg is set, then use it.
+ # Else:
+ # If a curve with the same legend exists, then use its arg value
+ # Else, use a default value.
+ # Store used value.
+ # This value is used when curve is updated either internally or by user.
+
+ def addCurve(self, x, y, legend=None, info=None,
+ replace=False, replot=None,
+ color=None, symbol=None,
+ linewidth=None, linestyle=None,
+ xlabel=None, ylabel=None, yaxis=None,
+ xerror=None, yerror=None, z=None, selectable=None,
+ fill=None, resetzoom=True,
+ histogram=None, copy=True, **kw):
+ """Add a 1D curve given by x an y to the graph.
+
+ Curves are uniquely identified by their legend.
+ To add multiple curves, call :meth:`addCurve` multiple times with
+ different legend argument.
+ To replace an existing curve, call :meth:`addCurve` with the
+ existing curve legend.
+ If you want to display the curve values as an histogram see the
+ histogram parameter or :meth:`addHistogram`.
+
+ When curve parameters are not provided, if a curve with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ If you attempt to plot an histogram you can set edges values in x.
+ In this case len(x) = len(y) + 1
+ :param numpy.ndarray y: The data corresponding to the y coordinates
+ :param str legend: The legend to be associated to the curve (or None)
+ :param info: User-defined information associated to the curve
+ :param bool replace: True (the default) to delete already existing
+ curves
+ :param color: color(s) to be used
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in Colors.py
+ :param str symbol: Symbol to be drawn at each (x, y) position::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+ - None (the default) to use default symbol
+
+ :param float linewidth: The width of the curve in pixels (Default: 1).
+ :param str linestyle: Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ - None (the default) to use default line style
+
+ :param str xlabel: Label to show on the X axis when the curve is active
+ or None to keep default axis label.
+ :param str ylabel: Label to show on the Y axis when the curve is active
+ or None to keep default axis label.
+ :param str yaxis: The Y axis this curve is attached to.
+ Either 'left' (the default) or 'right'
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param int z: Layer on which to draw the curve (default: 1)
+ This allows to control the overlay.
+ :param bool selectable: Indicate if the curve can be selected.
+ (Default: True)
+ :param bool fill: True to fill the curve, False otherwise (default).
+ :param bool resetzoom: True (the default) to reset the zoom.
+ :param str histogram: if not None then the curve will be draw as an
+ histogram. The step for each values of the curve can be set to the
+ left, center or right of the original x curve values.
+ If histogram is not None and len(x) == len(y)+1 then x is directly
+ take as edges of the histogram.
+ Type of histogram::
+
+ - None (default)
+ - 'left'
+ - 'right'
+ - 'center'
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :returns: The key string identify this curve
+ """
+ # Deprecation warnings
+ if replot is not None:
+ _logger.warning(
+ 'addCurve deprecated replot argument, use resetzoom instead')
+ resetzoom = replot and resetzoom
+
+ if kw:
+ _logger.warning('addCurve: deprecated extra arguments')
+
+ # This is an histogram, use addHistogram
+ if histogram is not None:
+ histoLegend = self.addHistogram(histogram=y,
+ edges=x,
+ legend=legend,
+ color=color,
+ fill=fill,
+ align=histogram,
+ copy=copy)
+ histo = self.getHistogram(histoLegend)
+
+ histo.setInfo(info)
+ if linewidth is not None:
+ histo.setLineWidth(linewidth)
+ if linestyle is not None:
+ histo.setLineStyle(linestyle)
+ if xlabel is not None:
+ _logger.warning(
+ 'addCurve: Histogram does not support xlabel argument')
+ if ylabel is not None:
+ _logger.warning(
+ 'addCurve: Histogram does not support ylabel argument')
+ if yaxis is not None:
+ histo.setYAxis(yaxis)
+ if z is not None:
+ histo.setZValue(z)
+ if selectable is not None:
+ _logger.warning(
+ 'addCurve: Histogram does not support selectable argument')
+
+ return
+
+ legend = 'Unnamed curve 1.1' if legend is None else str(legend)
+
+ # Check if curve was previously active
+ wasActive = self.getActiveCurve(just_legend=True) == legend
+
+ # Create/Update curve object
+ curve = self.getCurve(legend)
+ if curve is None:
+ # No previous curve, create a default one and add it to the plot
+ curve = items.Curve() if histogram is None else items.Histogram()
+ curve._setLegend(legend)
+ # Set default color, linestyle and symbol
+ default_color, default_linestyle = self._getColorAndStyle()
+ curve.setColor(default_color)
+ curve.setLineStyle(default_linestyle)
+ curve.setSymbol(self._defaultPlotPoints)
+ self._add(curve)
+
+ # Override previous/default values with provided ones
+ curve.setInfo(info)
+ if color is not None:
+ curve.setColor(color)
+ if symbol is not None:
+ curve.setSymbol(symbol)
+ if linewidth is not None:
+ curve.setLineWidth(linewidth)
+ if linestyle is not None:
+ curve.setLineStyle(linestyle)
+ if xlabel is not None:
+ curve._setXLabel(xlabel)
+ if ylabel is not None:
+ curve._setYLabel(ylabel)
+ if yaxis is not None:
+ curve.setYAxis(yaxis)
+ if z is not None:
+ curve.setZValue(z)
+ if selectable is not None:
+ curve._setSelectable(selectable)
+ if fill is not None:
+ curve.setFill(fill)
+
+ # Set curve data
+ # If errors not provided, reuse previous ones
+ # TODO: Issue if size of data change but not that of errors
+ if xerror is None:
+ xerror = curve.getXErrorData(copy=False)
+ if yerror is None:
+ yerror = curve.getYErrorData(copy=False)
+
+ curve.setData(x, y, xerror, yerror, copy=copy)
+
+ if replace: # Then remove all other curves
+ for c in self.getAllCurves(withhidden=True):
+ if c is not curve:
+ self._remove(c)
+
+ self.notify(
+ 'contentChanged', action='add', kind='curve', legend=legend)
+
+ if wasActive:
+ self.setActiveCurve(curve.getLegend())
+
+ if resetzoom:
+ # We ask for a zoom reset in order to handle the plot scaling
+ # if the user does not want that, autoscale of the different
+ # axes has to be set to off.
+ self.resetZoom()
+
+ return legend
+
+ def addHistogram(self,
+ histogram,
+ edges,
+ legend=None,
+ color=None,
+ fill=None,
+ align='center',
+ resetzoom=True,
+ copy=True):
+ """Add an histogram to the graph.
+
+ This is NOT computing the histogram, this method takes as parameter
+ already computed histogram values.
+
+ Histogram are uniquely identified by their legend.
+ To add multiple histograms, call :meth:`addHistogram` multiple times
+ with different legend argument.
+
+ When histogram parameters are not provided, if an histogram with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray histogram: The values of the histogram.
+ :param numpy.ndarray edges:
+ The bin edges of the histogram.
+ If histogram and edges have the same length, the bin edges
+ are computed according to the align parameter.
+ :param str legend:
+ The legend to be associated to the histogram (or None)
+ :param color: color to be used
+ :type color: str ("#RRGGBB") or RGB unsigned byte array or
+ one of the predefined color names defined in Colors.py
+ :param bool fill: True to fill the curve, False otherwise (default).
+ :param str align:
+ In case histogram values and edges have the same length N,
+ the N+1 bin edges are computed according to the alignment in:
+ 'center' (default), 'left', 'right'.
+ :param bool resetzoom: True (the default) to reset the zoom.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :returns: The key string identify this histogram
+ """
+ legend = 'Unnamed histogram' if legend is None else str(legend)
+
+ # Create/Update histogram object
+ histo = self.getHistogram(legend)
+ if histo is None:
+ # No previous histogram, create a default one and
+ # add it to the plot
+ histo = items.Histogram()
+ histo._setLegend(legend)
+ histo.setColor(self._getColorAndStyle()[0])
+ self._add(histo)
+
+ # Override previous/default values with provided ones
+ if color is not None:
+ histo.setColor(color)
+ if fill is not None:
+ histo.setFill(fill)
+
+ # Set histogram data
+ histo.setData(histogram, edges, align=align, copy=copy)
+
+ self.notify(
+ 'contentChanged', action='add', kind='histogram', legend=legend)
+
+ if resetzoom:
+ # We ask for a zoom reset in order to handle the plot scaling
+ # if the user does not want that, autoscale of the different
+ # axes has to be set to off.
+ self.resetZoom()
+
+ return legend
+
+ def addImage(self, data, legend=None, info=None,
+ replace=True, replot=None,
+ xScale=None, yScale=None, z=None,
+ selectable=None, draggable=None,
+ colormap=None, pixmap=None,
+ xlabel=None, ylabel=None,
+ origin=None, scale=None,
+ resetzoom=True, copy=True, **kw):
+ """Add a 2D dataset or an image to the plot.
+
+ It displays either an array of data using a colormap or a RGB(A) image.
+
+ Images are uniquely identified by their legend.
+ To add multiple images, call :meth:`addImage` multiple times with
+ different legend argument.
+ To replace/update an existing image, call :meth:`addImage` with the
+ existing image legend.
+
+ When image parameters are not provided, if an image with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray data: (nrows, ncolumns) data or
+ (nrows, ncolumns, RGBA) ubyte array
+ :param str legend: The legend to be associated to the image (or None)
+ :param info: User-defined information associated to the image
+ :param bool replace: True (default) to delete already existing images
+ :param int z: Layer on which to draw the image (default: 0)
+ This allows to control the overlay.
+ :param bool selectable: Indicate if the image can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the image can be moved.
+ (default: False)
+ :param dict colormap: Description of the colormap to use (or None)
+ This is ignored if data is a RGB(A) image.
+ See :mod:`Plot` for the documentation
+ of the colormap dict.
+ :param pixmap: Pixmap representation of the data (if any)
+ :type pixmap: (nrows, ncolumns, RGBA) ubyte array or None (default)
+ :param str xlabel: X axis label to show when this curve is active,
+ or None to keep default axis label.
+ :param str ylabel: Y axis label to show when this curve is active,
+ or None to keep default axis label.
+ :param origin: (origin X, origin Y) of the data.
+ It is possible to pass a single float if both
+ coordinates are equal.
+ Default: (0., 0.)
+ :type origin: float or 2-tuple of float
+ :param scale: (scale X, scale Y) of the data.
+ It is possible to pass a single float if both
+ coordinates are equal.
+ Default: (1., 1.)
+ :type scale: float or 2-tuple of float
+ :param bool resetzoom: True (the default) to reset the zoom.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :returns: The key string identify this image
+ """
+ # Deprecation warnings
+ if xScale is not None or yScale is not None:
+ _logger.warning(
+ 'addImage deprecated xScale and yScale arguments,'
+ 'use origin, scale arguments instead.')
+ if origin is None and scale is None:
+ origin = xScale[0], yScale[0]
+ scale = xScale[1], yScale[1]
+ else:
+ _logger.warning(
+ 'addCurve: xScale, yScale and origin, scale arguments'
+ ' are conflicting. xScale and yScale are ignored.'
+ ' Use only origin, scale arguments.')
+
+ if replot is not None:
+ _logger.warning(
+ 'addImage deprecated replot argument, use resetzoom instead')
+ resetzoom = replot and resetzoom
+
+ if kw:
+ _logger.warning('addImage: deprecated extra arguments')
+
+ legend = "Unnamed Image 1.1" if legend is None else str(legend)
+
+ # Check if image was previously active
+ wasActive = self.getActiveImage(just_legend=True) == legend
+
+ data = numpy.array(data, copy=False)
+ assert data.ndim in (2, 3)
+
+ image = self.getImage(legend)
+ if image is not None and image.getData(copy=False).ndim != data.ndim:
+ # Update a data image with RGBA image or the other way around:
+ # Remove previous image
+ # In this case, we don't retrieve defaults from the previous image
+ self._remove(image)
+ image = None
+
+ if image is None:
+ # No previous image, create a default one and add it to the plot
+ if data.ndim == 2:
+ image = items.ImageData()
+ image.setColormap(self.getDefaultColormap())
+ else:
+ image = items.ImageRgba()
+ image._setLegend(legend)
+ self._add(image)
+
+ # Override previous/default values with provided ones
+ image.setInfo(info)
+ if origin is not None:
+ image.setOrigin(origin)
+ if scale is not None:
+ image.setScale(scale)
+ if z is not None:
+ image.setZValue(z)
+ if selectable is not None:
+ image._setSelectable(selectable)
+ if draggable is not None:
+ image._setDraggable(draggable)
+ if colormap is not None and isinstance(image, items.ColormapMixIn):
+ image.setColormap(colormap)
+ if xlabel is not None:
+ image._setXLabel(xlabel)
+ if ylabel is not None:
+ image._setYLabel(ylabel)
+
+ if data.ndim == 2:
+ image.setData(data, alternative=pixmap, copy=copy)
+ else: # RGB(A) image
+ if pixmap is not None:
+ _logger.warning(
+ 'addImage: pixmap argument ignored when data is RGB(A)')
+ image.setData(data, copy=copy)
+
+ if replace:
+ for img in self.getAllImages():
+ if img is not image:
+ self._remove(img)
+
+ if len(self.getAllImages()) == 1 or wasActive:
+ self.setActiveImage(legend)
+
+ self.notify(
+ 'contentChanged', action='add', kind='image', legend=legend)
+
+ if resetzoom:
+ # We ask for a zoom reset in order to handle the plot scaling
+ # if the user does not want that, autoscale of the different
+ # axes has to be set to off.
+ self.resetZoom()
+
+ return legend
+
+ def addScatter(self, x, y, value, legend=None, colormap=None,
+ info=None, symbol=None, xerror=None, yerror=None,
+ z=None, copy=True):
+ """Add a (x, y, value) scatter to the graph.
+
+ Scatters are uniquely identified by their legend.
+ To add multiple scatters, call :meth:`addScatter` multiple times with
+ different legend argument.
+ To replace/update an existing scatter, call :meth:`addScatter` with the
+ existing scatter legend.
+
+ When scatter parameters are not provided, if a scatter with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates
+ :param numpy.ndarray value: The data value associated with each point
+ :param str legend: The legend to be associated to the scatter (or None)
+ :param dict colormap: The colormap to be used for the scatter (or None)
+ See :mod:`Plot` for the documentation
+ of the colormap dict.
+ :param info: User-defined information associated to the curve
+ :param str symbol: Symbol to be drawn at each (x, y) position::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+ - None (the default) to use default symbol
+
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param int z: Layer on which to draw the scatter (default: 1)
+ This allows to control the overlay.
+
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :returns: The key string identify this scatter
+ """
+ legend = 'Unnamed scatter 1.1' if legend is None else str(legend)
+
+ # Check if scatter was previously active
+ wasActive = self._getActiveItem(kind='scatter',
+ just_legend=True) == legend
+
+ # Create/Update curve object
+ scatter = self._getItem(kind='scatter', legend=legend)
+ if scatter is None:
+ # No previous scatter, create a default one and add it to the plot
+ scatter = items.Scatter()
+ scatter._setLegend(legend)
+ scatter.setColormap(self.getDefaultColormap())
+ self._add(scatter)
+
+ # Override previous/default values with provided ones
+ scatter.setInfo(info)
+ if symbol is not None:
+ scatter.setSymbol(symbol)
+ if z is not None:
+ scatter.setZValue(z)
+ if colormap is not None:
+ scatter.setColormap(colormap)
+
+ # Set scatter data
+ # If errors not provided, reuse previous ones
+ if xerror is None:
+ xerror = scatter.getXErrorData(copy=False)
+ if xerror is not None and len(xerror) != len(x):
+ xerror = None
+ if yerror is None:
+ yerror = scatter.getYErrorData(copy=False)
+ if yerror is not None and len(yerror) != len(y):
+ yerror = None
+
+ scatter.setData(x, y, value, xerror, yerror, copy=copy)
+
+ self.notify(
+ 'contentChanged', action='add', kind='scatter', legend=legend)
+
+ if len(self._getItems(kind="scatter")) == 1 or wasActive:
+ self._setActiveItem('scatter', scatter.getLegend())
+
+ return legend
+
+ def addItem(self, xdata, ydata, legend=None, info=None,
+ replace=False,
+ shape="polygon", color='black', fill=True,
+ overlay=False, z=None, **kw):
+ """Add an item (i.e. a shape) to the plot.
+
+ Items are uniquely identified by their legend.
+ To add multiple items, call :meth:`addItem` multiple times with
+ different legend argument.
+ To replace/update an existing item, call :meth:`addItem` with the
+ existing item legend.
+
+ :param numpy.ndarray xdata: The X coords of the points of the shape
+ :param numpy.ndarray ydata: The Y coords of the points of the shape
+ :param str legend: The legend to be associated to the item
+ :param info: User-defined information associated to the item
+ :param bool replace: True (default) to delete already existing images
+ :param str shape: Type of item to be drawn in
+ hline, polygon (the default), rectangle, vline,
+ polylines
+ :param str color: Color of the item, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool fill: True (the default) to fill the shape
+ :param bool overlay: True if item is an overlay (Default: False).
+ This allows for rendering optimization if this
+ item is changed often.
+ :param int z: Layer on which to draw the item (default: 2)
+ :returns: The key string identify this item
+ """
+ # expected to receive the same parameters as the signal
+
+ if kw:
+ _logger.warning('addItem deprecated parameters: %s', str(kw))
+
+ legend = "Unnamed Item 1.1" if legend is None else str(legend)
+
+ z = int(z) if z is not None else 2
+
+ if replace:
+ self.remove(kind='item')
+ else:
+ self.remove(legend, kind='item')
+
+ item = items.Shape(shape)
+ item._setLegend(legend)
+ item.setInfo(info)
+ item.setColor(color)
+ item.setFill(fill)
+ item.setOverlay(overlay)
+ item.setZValue(z)
+ item.setPoints(numpy.array((xdata, ydata)).T)
+
+ self._add(item)
+
+ self.notify('contentChanged', action='add', kind='item', legend=legend)
+
+ return legend
+
+ def addXMarker(self, x, legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ constraint=None,
+ **kw):
+ """Add a vertical line marker to the plot.
+
+ Markers are uniquely identified by their legend.
+ As opposed to curves, images and items, two calls to
+ :meth:`addXMarker` without legend argument adds two markers with
+ different identifying legends.
+
+ :param float x: Position of the marker on the X axis in data
+ coordinates
+ :param str legend: Legend associated to the marker to identify it
+ :param str text: Text to display on the marker.
+ :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool selectable: Indicate if the marker can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the marker can be moved.
+ (default: False)
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ This parameter is only used if draggable is True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :return: The key string identify this marker
+ """
+ if kw:
+ _logger.warning(
+ 'addXMarker deprecated extra parameters: %s', str(kw))
+
+ return self._addMarker(x=x, y=None, legend=legend,
+ text=text, color=color,
+ selectable=selectable, draggable=draggable,
+ symbol=None, constraint=constraint)
+
+ def addYMarker(self, y,
+ legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ constraint=None,
+ **kw):
+ """Add a horizontal line marker to the plot.
+
+ Markers are uniquely identified by their legend.
+ As opposed to curves, images and items, two calls to
+ :meth:`addYMarker` without legend argument adds two markers with
+ different identifying legends.
+
+ :param float y: Position of the marker on the Y axis in data
+ coordinates
+ :param str legend: Legend associated to the marker to identify it
+ :param str text: Text to display next to the marker.
+ :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool selectable: Indicate if the marker can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the marker can be moved.
+ (default: False)
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ This parameter is only used if draggable is True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :return: The key string identify this marker
+ """
+ if kw:
+ _logger.warning(
+ 'addYMarker deprecated extra parameters: %s', str(kw))
+
+ return self._addMarker(x=None, y=y, legend=legend,
+ text=text, color=color,
+ selectable=selectable, draggable=draggable,
+ symbol=None, constraint=constraint)
+
+ def addMarker(self, x, y, legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ symbol='+',
+ constraint=None,
+ **kw):
+ """Add a point marker to the plot.
+
+ Markers are uniquely identified by their legend.
+ As opposed to curves, images and items, two calls to
+ :meth:`addMarker` without legend argument adds two markers with
+ different identifying legends.
+
+ :param float x: Position of the marker on the X axis in data
+ coordinates
+ :param float y: Position of the marker on the Y axis in data
+ coordinates
+ :param str legend: Legend associated to the marker to identify it
+ :param str text: Text to display next to the marker
+ :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool selectable: Indicate if the marker can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the marker can be moved.
+ (default: False)
+ :param str symbol: Symbol representing the marker in::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross (the default)
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ This parameter is only used if draggable is True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :return: The key string identify this marker
+ """
+ if kw:
+ _logger.warning(
+ 'addMarker deprecated extra parameters: %s', str(kw))
+
+ if x is None:
+ xmin, xmax = self.getGraphXLimits()
+ x = 0.5 * (xmax + xmin)
+
+ if y is None:
+ ymin, ymax = self.getGraphYLimits()
+ y = 0.5 * (ymax + ymin)
+
+ return self._addMarker(x=x, y=y, legend=legend,
+ text=text, color=color,
+ selectable=selectable, draggable=draggable,
+ symbol=symbol, constraint=constraint)
+
+ def _addMarker(self, x, y, legend,
+ text, color,
+ selectable, draggable,
+ symbol, constraint):
+ """Common method for adding point, vline and hline marker.
+
+ See :meth:`addMarker` for argument documentation.
+ """
+ assert (x, y) != (None, None)
+
+ if legend is None: # Find an unused legend
+ markerLegends = self._getAllMarkers(just_legend=True)
+ for index in itertools.count():
+ legend = "Unnamed Marker %d" % index
+ if legend not in markerLegends:
+ break # Keep this legend
+ legend = str(legend)
+
+ if x is None:
+ markerClass = items.YMarker
+ elif y is None:
+ markerClass = items.XMarker
+ else:
+ markerClass = items.Marker
+
+ # Create/Update marker object
+ marker = self._getMarker(legend)
+ if marker is not None and not isinstance(marker, markerClass):
+ _logger.warning('Adding marker with same legend'
+ ' but different type replaces it')
+ self._remove(marker)
+ marker = None
+
+ if marker is None:
+ # No previous marker, create one
+ marker = markerClass()
+ marker._setLegend(legend)
+ self._add(marker)
+
+ if text is not None:
+ marker.setText(text)
+ if color is not None:
+ marker.setColor(color)
+ if selectable is not None:
+ marker._setSelectable(selectable)
+ if draggable is not None:
+ marker._setDraggable(draggable)
+ if symbol is not None:
+ marker.setSymbol(symbol)
+
+ # TODO to improve, but this ensure constraint is applied
+ marker.setPosition(x, y)
+ if constraint is not None:
+ marker._setConstraint(constraint)
+ marker.setPosition(x, y)
+
+ self.notify(
+ 'contentChanged', action='add', kind='marker', legend=legend)
+
+ return legend
+
+ # Hide
+
+ def isCurveHidden(self, legend):
+ """Returns True if the curve associated to legend is hidden, else False
+
+ :param str legend: The legend key identifying the curve
+ :return: True if the associated curve is hidden, False otherwise
+ """
+ curve = self._getItem('curve', legend)
+ return curve is not None and not curve.isVisible()
+
+ def hideCurve(self, legend, flag=True, replot=None):
+ """Show/Hide the curve associated to legend.
+
+ Even when hidden, the curve is kept in the list of curves.
+
+ :param str legend: The legend associated to the curve to be hidden
+ :param bool flag: True (default) to hide the curve, False to show it
+ """
+ if replot is not None:
+ _logger.warning('hideCurve deprecated replot parameter')
+
+ curve = self._getItem('curve', legend)
+ if curve is None:
+ _logger.warning('Curve not in plot: %s', legend)
+ return
+
+ isVisible = not flag
+ if isVisible != curve.isVisible():
+ curve.setVisible(isVisible)
+
+ # Remove
+
+ ITEM_KINDS = 'curve', 'image', 'scatter', 'item', 'marker', 'histogram'
+
+ def remove(self, legend=None, kind=ITEM_KINDS):
+ """Remove one or all element(s) of the given legend and kind.
+
+ Examples:
+
+ - ``remove()`` clears the plot
+ - ``remove(kind='curve')`` removes all curves from the plot
+ - ``remove('myCurve', kind='curve')`` removes the curve with
+ legend 'myCurve' from the plot.
+ - ``remove('myImage, kind='image')`` removes the image with
+ legend 'myImage' from the plot.
+ - ``remove('myImage')`` removes elements (for instance curve, image,
+ item and marker) with legend 'myImage'.
+
+ :param str legend: The legend associated to the element to remove,
+ or None to remove
+ :param kind: The kind of elements to remove from the plot.
+ In: 'all', 'curve', 'image', 'item', 'marker'.
+ By default, it removes all kind of elements.
+ :type kind: str or tuple of str to specify multiple kinds.
+ """
+ if kind is 'all': # Replace all by tuple of all kinds
+ kind = self.ITEM_KINDS
+
+ if kind in self.ITEM_KINDS: # Kind is a str, make it a tuple
+ kind = (kind,)
+
+ for aKind in kind:
+ assert aKind in self.ITEM_KINDS
+
+ if legend is None: # This is a clear
+ # Clear each given kind
+ for aKind in kind:
+ for legend in self._getItems(
+ kind=aKind, just_legend=True, withhidden=True):
+ self.remove(legend=legend, kind=aKind)
+
+ else: # This is removing a single element
+ # Remove each given kind
+ for aKind in kind:
+ item = self._getItem(aKind, legend)
+ if item is not None:
+ if aKind in ('curve', 'image'):
+ if self._getActiveItem(aKind) == item:
+ # Reset active item
+ self._setActiveItem(aKind, None)
+
+ self._remove(item)
+
+ if (aKind == 'curve' and
+ not self.getAllCurves(just_legend=True,
+ withhidden=True)):
+ self._colorIndex = 0
+ self._styleIndex = 0
+
+ self.notify('contentChanged', action='remove',
+ kind=aKind, legend=legend)
+
+ def removeCurve(self, legend):
+ """Remove the curve associated to legend from the graph.
+
+ :param str legend: The legend associated to the curve to be deleted
+ """
+ if legend is None:
+ return
+ self.remove(legend, kind='curve')
+
+ def removeImage(self, legend):
+ """Remove the image associated to legend from the graph.
+
+ :param str legend: The legend associated to the image to be deleted
+ """
+ if legend is None:
+ return
+ self.remove(legend, kind='image')
+
+ def removeItem(self, legend):
+ """Remove the item associated to legend from the graph.
+
+ :param str legend: The legend associated to the item to be deleted
+ """
+ if legend is None:
+ return
+ self.remove(legend, kind='item')
+
+ def removeMarker(self, legend):
+ """Remove the marker associated to legend from the graph.
+
+ :param str legend: The legend associated to the marker to be deleted
+ """
+ if legend is None:
+ return
+ self.remove(legend, kind='marker')
+
+ # Clear
+
+ def clear(self):
+ """Remove everything from the plot."""
+ self.remove()
+
+ def clearCurves(self):
+ """Remove all the curves from the plot."""
+ self.remove(kind='curve')
+
+ def clearImages(self):
+ """Remove all the images from the plot."""
+ self.remove(kind='image')
+
+ def clearItems(self):
+ """Remove all the items from the plot. """
+ self.remove(kind='item')
+
+ def clearMarkers(self):
+ """Remove all the markers from the plot."""
+ self.remove(kind='marker')
+
+ # Interaction
+
+ def getGraphCursor(self):
+ """Returns the state of the crosshair cursor.
+
+ See :meth:`setGraphCursor`.
+
+ :return: None if the crosshair cursor is not active,
+ else a tuple (color, linewidth, linestyle).
+ """
+ return self._cursorConfiguration
+
+ def setGraphCursor(self, flag=False, color='black',
+ linewidth=1, linestyle='-'):
+ """Toggle the display of a crosshair cursor and set its attributes.
+
+ :param bool flag: Toggle the display of a crosshair cursor.
+ The crosshair cursor is hidden by default.
+ :param color: The color to use for the crosshair.
+ :type color: A string (either a predefined color name in Colors.py
+ or "#RRGGBB")) or a 4 columns unsigned byte array
+ (Default: black).
+ :param int linewidth: The width of the lines of the crosshair
+ (Default: 1).
+ :param str linestyle: Type of line::
+
+ - ' ' no line
+ - '-' solid line (the default)
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ """
+ if flag:
+ self._cursorConfiguration = color, linewidth, linestyle
+ else:
+ self._cursorConfiguration = None
+
+ self._backend.setGraphCursor(flag=flag, color=color,
+ linewidth=linewidth, linestyle=linestyle)
+ self._setDirtyPlot()
+ self.notify('setGraphCursor',
+ state=self._cursorConfiguration is not None)
+
+ def pan(self, direction, factor=0.1):
+ """Pan the graph in the given direction by the given factor.
+
+ Warning: Pan of right Y axis not implemented!
+
+ :param str direction: One of 'up', 'down', 'left', 'right'.
+ :param float factor: Proportion of the range used to pan the graph.
+ Must be strictly positive.
+ """
+ assert direction in ('up', 'down', 'left', 'right')
+ assert factor > 0.
+
+ if direction in ('left', 'right'):
+ xFactor = factor if direction == 'right' else - factor
+ xMin, xMax = self.getGraphXLimits()
+
+ xMin, xMax = _utils.applyPan(xMin, xMax, xFactor,
+ self.isXAxisLogarithmic())
+ self.setGraphXLimits(xMin, xMax)
+
+ else: # direction in ('up', 'down')
+ sign = -1. if self.isYAxisInverted() else 1.
+ yFactor = sign * (factor if direction == 'up' else -factor)
+ yMin, yMax = self.getGraphYLimits()
+ yIsLog = self.isYAxisLogarithmic()
+
+ yMin, yMax = _utils.applyPan(yMin, yMax, yFactor, yIsLog)
+ self.setGraphYLimits(yMin, yMax, axis='left')
+
+ y2Min, y2Max = self.getGraphYLimits(axis='right')
+
+ y2Min, y2Max = _utils.applyPan(y2Min, y2Max, yFactor, yIsLog)
+ self.setGraphYLimits(y2Min, y2Max, axis='right')
+
+ # Active Curve/Image
+
+ def isActiveCurveHandling(self):
+ """Returns True if active curve selection is enabled."""
+ return self._activeCurveHandling
+
+ def setActiveCurveHandling(self, flag=True):
+ """Enable/Disable active curve selection.
+
+ :param bool flag: True (the default) to enable active curve selection.
+ """
+ if not flag:
+ self.setActiveCurve(None) # Reset active curve
+
+ self._activeCurveHandling = bool(flag)
+
+ def getActiveCurveColor(self):
+ """Get the color used to display the currently active curve.
+
+ See :meth:`setActiveCurveColor`.
+ """
+ return self._activeCurveColor
+
+ def setActiveCurveColor(self, color="#000000"):
+ """Set the color to use to display the currently active curve.
+
+ :param str color: Color of the active curve,
+ e.g., 'blue', 'b', '#FF0000' (Default: 'black')
+ """
+ if color is None:
+ color = "black"
+ if color in self.colorDict:
+ color = self.colorDict[color]
+ self._activeCurveColor = color
+
+ def getActiveCurve(self, just_legend=False):
+ """Return the currently active curve.
+
+ It returns None in case of not having an active curve.
+
+ :param bool just_legend: True to get the legend of the curve,
+ False (the default) to get the curve data
+ and info.
+ :return: Active curve's legend or corresponding
+ :class:`.items.Curve`
+ :rtype: str or :class:`.items.Curve` or None
+ """
+ if not self.isActiveCurveHandling():
+ return None
+
+ return self._getActiveItem(kind='curve', just_legend=just_legend)
+
+ def setActiveCurve(self, legend, replot=None):
+ """Make the curve associated to legend the active curve.
+
+ :param legend: The legend associated to the curve
+ or None to have no active curve.
+ :type legend: str or None
+ """
+ if replot is not None:
+ _logger.warning('setActiveCurve deprecated replot parameter')
+
+ if not self.isActiveCurveHandling():
+ return
+
+ return self._setActiveItem(kind='curve', legend=legend)
+
+ def getActiveImage(self, just_legend=False):
+ """Returns the currently active image.
+
+ It returns None in case of not having an active image.
+
+ :param bool just_legend: True to get the legend of the image,
+ False (the default) to get the image data
+ and info.
+ :return: Active image's legend or corresponding image object
+ :rtype: str, :class:`.items.ImageData`, :class:`.items.ImageRgba`
+ or None
+ """
+ return self._getActiveItem(kind='image', just_legend=just_legend)
+
+ def setActiveImage(self, legend, replot=None):
+ """Make the image associated to legend the active image.
+
+ :param str legend: The legend associated to the image
+ or None to have no active image.
+ """
+ if replot is not None:
+ _logger.warning('setActiveImage deprecated replot parameter')
+
+ return self._setActiveItem(kind='image', legend=legend)
+
+ def _getActiveItem(self, kind, just_legend=False):
+ """Return the currently active item of that kind if any
+
+ :param str kind: Type of item: 'curve', 'scatter' or 'image'
+ :param bool just_legend: True to get the legend,
+ False (default) to get the item
+ :return: legend or item or None if no active item
+ """
+ assert kind in ('curve', 'scatter', 'image')
+
+ if self._activeLegend[kind] is None:
+ return None
+
+ if (self._activeLegend[kind], kind) not in self._content:
+ self._activeLegend[kind] = None
+ return None
+
+ if just_legend:
+ return self._activeLegend[kind]
+ else:
+ return self._getItem(kind, self._activeLegend[kind])
+
+ def _setActiveItem(self, kind, legend):
+ """Make the curve associated to legend the active curve.
+
+ :param str kind: Type of item: 'curve' or 'image'
+ :param legend: The legend associated to the curve
+ or None to have no active curve.
+ :type legend: str or None
+ """
+ assert kind in ('curve', 'image', 'scatter')
+
+ xLabel = self._defaultLabels['x']
+ yLabel = self._defaultLabels['y']
+ yRightLabel = self._defaultLabels['yright']
+
+ oldActiveItem = self._getActiveItem(kind=kind)
+
+ # Curve specific: Reset highlight of previous active curve
+ if kind == 'curve' and oldActiveItem is not None:
+ oldActiveItem.setHighlighted(False)
+
+ if legend is None:
+ self._activeLegend[kind] = None
+ else:
+ legend = str(legend)
+ item = self._getItem(kind, legend)
+ if item is None:
+ _logger.warning("This %s does not exist: %s", kind, legend)
+ self._activeLegend[kind] = None
+ else:
+ self._activeLegend[kind] = legend
+
+ # Curve specific: handle highlight
+ if kind == 'curve':
+ item.setHighlightedColor(self.getActiveCurveColor())
+ item.setHighlighted(True)
+
+ if isinstance(item, items.LabelsMixIn):
+ if item.getXLabel() is not None:
+ xLabel = item.getXLabel()
+ if item.getYLabel() is not None:
+ if (isinstance(item, items.YAxisMixIn) and
+ item.getYAxis() == 'right'):
+ yRightLabel = item.getYLabel()
+ else:
+ yLabel = item.getYLabel()
+
+ # Store current labels and update plot
+ self._currentLabels['x'] = xLabel
+ self._currentLabels['y'] = yLabel
+ self._currentLabels['yright'] = yRightLabel
+
+ self._backend.setGraphXLabel(xLabel)
+ self._backend.setGraphYLabel(yLabel, axis='left')
+ self._backend.setGraphYLabel(yRightLabel, axis='right')
+
+ self._setDirtyPlot()
+
+ activeLegend = self._activeLegend[kind]
+ if oldActiveItem is not None or activeLegend is not None:
+ if oldActiveItem is None:
+ oldActiveLegend = None
+ else:
+ oldActiveLegend = oldActiveItem.getLegend()
+ self.notify(
+ 'active' + kind[0].upper() + kind[1:] + 'Changed',
+ updated=oldActiveLegend != activeLegend,
+ previous=oldActiveLegend,
+ legend=activeLegend)
+
+ return activeLegend
+
+ # Getters
+
+ def getAllCurves(self, just_legend=False, withhidden=False):
+ """Returns all curves legend or info and data.
+
+ It returns an empty list in case of not having any curve.
+
+ If just_legend is False, it returns a list of :class:`items.Curve`
+ objects describing the curves.
+ If just_legend is True, it returns a list of curves' legend.
+
+ :param bool just_legend: True to get the legend of the curves,
+ False (the default) to get the curves' data
+ and info.
+ :param bool withhidden: False (default) to skip hidden curves.
+ :return: list of curves' legend or :class:`.items.Curve`
+ :rtype: list of str or list of :class:`.items.Curve`
+ """
+ return self._getItems(kind='curve',
+ just_legend=just_legend,
+ withhidden=withhidden)
+
+ def getCurve(self, legend=None):
+ """Get the object describing a specific curve.
+
+ It returns None in case no matching curve is found.
+
+ :param str legend:
+ The legend identifying the curve.
+ If not provided or None (the default), the active curve is returned
+ or if there is no active curve, the latest updated curve that is
+ not hidden is returned if there are curves in the plot.
+ :return: None or :class:`.items.Curve` object
+ """
+ return self._getItem(kind='curve', legend=legend)
+
+ def getAllImages(self, just_legend=False):
+ """Returns all images legend or objects.
+
+ It returns an empty list in case of not having any image.
+
+ If just_legend is False, it returns a list of :class:`items.ImageBase`
+ objects describing the images.
+ If just_legend is True, it returns a list of legends.
+
+ :param bool just_legend: True to get the legend of the images,
+ False (the default) to get the images'
+ object.
+ :return: list of images' legend or :class:`.items.ImageBase`
+ :rtype: list of str or list of :class:`.items.ImageBase`
+ """
+ return self._getItems(kind='image',
+ just_legend=just_legend,
+ withhidden=True)
+
+ def getImage(self, legend=None):
+ """Get the object describing a specific image.
+
+ It returns None in case no matching image is found.
+
+ :param str legend:
+ The legend identifying the image.
+ If not provided or None (the default), the active image is returned
+ or if there is no active image, the latest updated image
+ is returned if there are images in the plot.
+ :return: None or :class:`.items.ImageBase` object
+ """
+ return self._getItem(kind='image', legend=legend)
+
+ def getScatter(self, legend=None):
+ """Get the object describing a specific scatter.
+
+ It returns None in case no matching scatter is found.
+
+ :param str legend:
+ The legend identifying the scatter.
+ If not provided or None (the default), the active scatter is
+ returned or if there is no active scatter, the latest updated
+ scatter is returned if there are scatters in the plot.
+ :return: None or :class:`.items.Scatter` object
+ """
+ return self._getItem(kind='scatter', legend=legend)
+
+ def getHistogram(self, legend=None):
+ """Get the object describing a specific histogram.
+
+ It returns None in case no matching histogram is found.
+
+ :param str legend:
+ The legend identifying the histogram.
+ If not provided or None (the default), the latest updated scatter
+ is returned if there are histograms in the plot.
+ :return: None or :class:`.items.Histogram` object
+ """
+ return self._getItem(kind='histogram', legend=legend)
+
+ def _getItems(self, kind, just_legend=False, withhidden=False):
+ """Retrieve all items of a kind in the plot
+
+ :param str kind: Type of item: 'curve' or 'image'
+ :param bool just_legend: True to get the legend of the curves,
+ False (the default) to get the curves' data
+ and info.
+ :param bool withhidden: False (default) to skip hidden curves.
+ :return: list of legends or item objects
+ """
+ assert kind in self.ITEM_KINDS
+ output = []
+ for (legend, type_), item in self._content.items():
+ if type_ == kind and (withhidden or item.isVisible()):
+ output.append(legend if just_legend else item)
+ return output
+
+ def _getItem(self, kind, legend=None):
+ """Get an item from the plot: either an image or a curve.
+
+ Returns None if no match found
+
+ :param str kind: Type of item: 'curve' or 'image'
+ :param str legend: Legend of the item or
+ None to get active or last item
+ :return: Object describing the item or None
+ """
+ assert kind in self.ITEM_KINDS
+
+ if legend is not None:
+ return self._content.get((legend, kind), None)
+ else:
+ if kind in ('curve', 'image', 'scatter'):
+ item = self._getActiveItem(kind=kind)
+ if item is not None: # Return active item if available
+ return item
+ # Return last visible item if any
+ allItems = self._getItems(
+ kind=kind, just_legend=False, withhidden=False)
+ return allItems[-1] if allItems else None
+
+ # Limits
+
+ def _notifyLimitsChanged(self):
+ """Send an event when plot area limits are changed."""
+ xRange = self.getGraphXLimits()
+ yRange = self.getGraphYLimits(axis='left')
+ y2Range = self.getGraphYLimits(axis='right')
+ event = PlotEvents.prepareLimitsChangedSignal(
+ id(self.getWidgetHandle()), xRange, yRange, y2Range)
+ self.notify(**event)
+
+ def _checkLimits(self, min_, max_, axis):
+ """Makes sure axis range is not empty
+
+ :param float min_: Min axis value
+ :param float max_: Max axis value
+ :param str axis: 'x', 'y' or 'y2' the axis to deal with
+ :return: (min, max) making sure min < max
+ :rtype: 2-tuple of float
+ """
+ if max_ < min_:
+ _logger.info('%s axis: max < min, inverting limits.', axis)
+ min_, max_ = max_, min_
+ elif max_ == min_:
+ _logger.info('%s axis: max == min, expanding limits.', axis)
+ if min_ == 0.:
+ min_, max_ = -0.1, 0.1
+ elif min_ < 0:
+ min_, max_ = min_ * 1.1, min_ * 0.9
+ else: # xmin > 0
+ min_, max_ = min_ * 0.9, min_ * 1.1
+
+ return min_, max_
+
+ def getGraphXLimits(self):
+ """Get the graph X (bottom) limits.
+
+ :return: Minimum and maximum values of the X axis
+ """
+ return self._backend.getGraphXLimits()
+
+ def setGraphXLimits(self, xmin, xmax, replot=None):
+ """Set the graph X (bottom) limits.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ """
+ if replot is not None:
+ _logger.warning('setGraphXLimits deprecated replot parameter')
+
+ xmin, xmax = self._checkLimits(xmin, xmax, axis='x')
+
+ self._backend.setGraphXLimits(xmin, xmax)
+ self._setDirtyPlot()
+
+ self._notifyLimitsChanged()
+
+ def getGraphYLimits(self, axis='left'):
+ """Get the graph Y limits.
+
+ :param str axis: The axis for which to get the limits:
+ Either 'left' or 'right'
+ :return: Minimum and maximum values of the X axis
+ """
+ assert axis in ('left', 'right')
+ return self._backend.getGraphYLimits(axis)
+
+ def setGraphYLimits(self, ymin, ymax, axis='left', replot=None):
+ """Set the graph Y limits.
+
+ :param float ymin: minimum bottom axis value
+ :param float ymax: maximum bottom axis value
+ :param str axis: The axis for which to get the limits:
+ Either 'left' or 'right'
+ """
+ if replot is not None:
+ _logger.warning('setGraphYLimits deprecated replot parameter')
+
+ assert axis in ('left', 'right')
+
+ ymin, ymax = self._checkLimits(ymin, ymax,
+ axis='y' if axis == 'left' else 'y2')
+
+ self._backend.setGraphYLimits(ymin, ymax, axis)
+ self._setDirtyPlot()
+
+ self._notifyLimitsChanged()
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ """Set the limits of the X and Y axes at once.
+
+ If y2min or y2max is None, the right Y axis limits are not updated.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ :param float ymin: minimum left axis value
+ :param float ymax: maximum left axis value
+ :param float y2min: minimum right axis value or None (the default)
+ :param float y2max: maximum right axis value or None (the default)
+ """
+ # Deal with incorrect values
+ xmin, xmax = self._checkLimits(xmin, xmax, axis='x')
+ ymin, ymax = self._checkLimits(ymin, ymax, axis='y')
+
+ if y2min is None or y2max is None:
+ # if one limit is None, both are ignored
+ y2min, y2max = None, None
+ else:
+ y2min, y2max = self._checkLimits(y2min, y2max, axis='y2')
+
+ self._backend.setLimits(xmin, xmax, ymin, ymax, y2min, y2max)
+ self._setDirtyPlot()
+ self._notifyLimitsChanged()
+
+ # Title and labels
+
+ def getGraphTitle(self):
+ """Return the plot main title as a str."""
+ return self._graphTitle
+
+ def setGraphTitle(self, title=""):
+ """Set the plot main title.
+
+ :param str title: Main title of the plot (default: '')
+ """
+ self._graphTitle = str(title)
+ self._backend.setGraphTitle(title)
+ self._setDirtyPlot()
+
+ def getGraphXLabel(self):
+ """Return the current X axis label as a str."""
+ return self._currentLabels['x']
+
+ def setGraphXLabel(self, label="X"):
+ """Set the plot X axis label.
+
+ The provided label can be temporarily replaced by the X label of the
+ active curve if any.
+
+ :param str label: The X axis label (default: 'X')
+ """
+ self._defaultLabels['x'] = label
+ self._currentLabels['x'] = label
+ self._backend.setGraphXLabel(label)
+ self._setDirtyPlot()
+
+ def getGraphYLabel(self, axis='left'):
+ """Return the current Y axis label as a str.
+
+ :param str axis: The Y axis for which to get the label (left or right)
+ """
+ assert axis in ('left', 'right')
+
+ return self._currentLabels['y' if axis == 'left' else 'yright']
+
+ def setGraphYLabel(self, label="Y", axis='left'):
+ """Set the plot Y axis label.
+
+ The provided label can be temporarily replaced by the Y label of the
+ active curve if any.
+
+ :param str label: The Y axis label (default: 'Y')
+ :param str axis: The Y axis for which to set the label (left or right)
+ """
+ assert axis in ('left', 'right')
+
+ if axis == 'left':
+ self._defaultLabels['y'] = label
+ self._currentLabels['y'] = label
+ else:
+ self._defaultLabels['yright'] = label
+ self._currentLabels['yright'] = label
+
+ self._backend.setGraphYLabel(label, axis=axis)
+ self._setDirtyPlot()
+
+ # Axes
+
+ def setYAxisInverted(self, flag=True):
+ """Set the Y axis orientation.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ flag = bool(flag)
+ self._backend.setYAxisInverted(flag)
+ self._setDirtyPlot()
+ self.notify('setYAxisInverted', state=flag)
+
+ def isYAxisInverted(self):
+ """Return True if Y axis goes from top to bottom, False otherwise."""
+ return self._backend.isYAxisInverted()
+
+ def isXAxisLogarithmic(self):
+ """Return True if X axis scale is logarithmic, False if linear."""
+ return self._logX
+
+ def setXAxisLogarithmic(self, flag):
+ """Set the bottom X axis scale (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ if bool(flag) == self._logX:
+ return
+ self._logX = bool(flag)
+
+ self._backend.setXAxisLogarithmic(self._logX)
+
+ # TODO hackish way of forcing update of curves and images
+ for curve in self.getAllCurves():
+ curve._updated()
+ for image in self.getAllImages():
+ image._updated()
+ self._invalidateDataRange()
+
+ self.resetZoom()
+ self.notify('setXAxisLogarithmic', state=self._logX)
+
+ def isYAxisLogarithmic(self):
+ """Return True if Y axis scale is logarithmic, False if linear."""
+ return self._logY
+
+ def setYAxisLogarithmic(self, flag):
+ """Set the Y axes scale (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ if bool(flag) == self._logY:
+ return
+ self._logY = bool(flag)
+
+ self._backend.setYAxisLogarithmic(self._logY)
+
+ # TODO hackish way of forcing update of curves and images
+ for curve in self.getAllCurves():
+ curve._updated()
+ for image in self.getAllImages():
+ image._updated()
+ self._invalidateDataRange()
+
+ self.resetZoom()
+ self.notify('setYAxisLogarithmic', state=self._logY)
+
+ def isXAxisAutoScale(self):
+ """Return True if X axis is automatically adjusting its limits."""
+ return self._xAutoScale
+
+ def setXAxisAutoScale(self, flag=True):
+ """Set the X axis limits adjusting behavior of :meth:`resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ self._xAutoScale = bool(flag)
+ self.notify('setXAxisAutoScale', state=self._xAutoScale)
+
+ def isYAxisAutoScale(self):
+ """Return True if Y axes are automatically adjusting its limits."""
+ return self._yAutoScale
+
+ def setYAxisAutoScale(self, flag=True):
+ """Set the Y axis limits adjusting behavior of :meth:`resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ self._yAutoScale = bool(flag)
+ self.notify('setYAxisAutoScale', state=self._yAutoScale)
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not."""
+ return self._backend.isKeepDataAspectRatio()
+
+ def setKeepDataAspectRatio(self, flag=True):
+ """Set whether the plot keeps data aspect ratio or not.
+
+ :param bool flag: True to respect data aspect ratio
+ """
+ flag = bool(flag)
+ self._backend.setKeepDataAspectRatio(flag=flag)
+ self._setDirtyPlot()
+ self.resetZoom()
+ self.notify('setKeepDataAspectRatio', state=flag)
+
+ def getGraphGrid(self):
+ """Return the current grid mode, either None, 'major' or 'both'.
+
+ See :meth:`setGraphGrid`.
+ """
+ return self._grid
+
+ def setGraphGrid(self, which=True):
+ """Set the type of grid to display.
+
+ :param which: None or False to disable the grid,
+ 'major' or True for grid on major ticks (the default),
+ 'both' for grid on both major and minor ticks.
+ :type which: str of bool
+ """
+ assert which in (None, True, False, 'both', 'major')
+ if not which:
+ which = None
+ elif which is True:
+ which = 'major'
+ self._grid = which
+ self._backend.setGraphGrid(which)
+ self._setDirtyPlot()
+ self.notify('setGraphGrid', which=str(which))
+
+ # Defaults
+
+ def isDefaultPlotPoints(self):
+ """Return True if default Curve symbol is 'o', False for no symbol."""
+ return self._defaultPlotPoints == 'o'
+
+ def setDefaultPlotPoints(self, flag):
+ """Set the default symbol of all curves.
+
+ When called, this reset the symbol of all existing curves.
+
+ :param bool flag: True to use 'o' as the default curve symbol,
+ False to use no symbol.
+ """
+ self._defaultPlotPoints = 'o' if flag else ''
+
+ # Reset symbol of all curves
+ curves = self.getAllCurves(just_legend=False, withhidden=True)
+
+ if curves:
+ for curve in curves:
+ curve.setSymbol(self._defaultPlotPoints)
+
+ def isDefaultPlotLines(self):
+ """Return True for line as default line style, False for no line."""
+ return self._plotLines
+
+ def setDefaultPlotLines(self, flag):
+ """Toggle the use of lines as the default curve line style.
+
+ :param bool flag: True to use a line as the default line style,
+ False to use no line as the default line style.
+ """
+ self._plotLines = bool(flag)
+
+ linestyle = '-' if self._plotLines else ' '
+
+ # Reset linestyle of all curves
+ curves = self.getAllCurves(withhidden=True)
+
+ if curves:
+ for curve in curves:
+ curve.setLineStyle(linestyle)
+
+ def getDefaultColormap(self):
+ """Return the default colormap used by :meth:`addImage` as a dict.
+
+ See :mod:`Plot` for the documentation of the colormap dict.
+ """
+ return self._defaultColormap.copy()
+
+ def setDefaultColormap(self, colormap=None):
+ """Set the default colormap used by :meth:`addImage`.
+
+ Setting the default colormap do not change any currently displayed
+ image.
+ It only affects future calls to :meth:`addImage` without the colormap
+ parameter.
+
+ :param dict colormap: The description of the default colormap, or
+ None to set the colormap to a linear autoscale
+ gray colormap.
+ See :mod:`Plot` for the documentation
+ of the colormap dict.
+ """
+ if colormap is None:
+ colormap = {'name': 'gray', 'normalization': 'linear',
+ 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0}
+ self._defaultColormap = colormap.copy()
+
+ def getSupportedColormaps(self):
+ """Get the supported colormap names as a tuple of str.
+
+ The list should at least contain and start by:
+ ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue')
+ """
+ default = ('gray', 'reversed gray',
+ 'temperature',
+ 'red', 'green', 'blue')
+ if matplotlib_cm is None:
+ return default
+ else:
+ maps = [m for m in matplotlib_cm.datad]
+ maps.sort()
+ return default + tuple(maps)
+
+ def _getColorAndStyle(self):
+ color = self.colorList[self._colorIndex]
+ style = self._styleList[self._styleIndex]
+
+ # Loop over color and then styles
+ self._colorIndex += 1
+ if self._colorIndex >= len(self.colorList):
+ self._colorIndex = 0
+ self._styleIndex = (self._styleIndex + 1) % len(self._styleList)
+
+ # If color is the one of active curve, take the next one
+ if color == self.getActiveCurveColor():
+ color, style = self._getColorAndStyle()
+
+ if not self._plotLines:
+ style = ' '
+
+ return color, style
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ """Return the widget the plot is displayed in.
+
+ This widget is owned by the backend.
+ """
+ return self._backend.getWidgetHandle()
+
+ def notify(self, event, **kwargs):
+ """Send an event to the listeners.
+
+ Event are passed to the registered callback as a dict with an 'event'
+ key for backward compatibility with PyMca.
+
+ :param str event: The type of event
+ :param kwargs: The information of the event.
+ """
+ eventDict = kwargs.copy()
+ eventDict['event'] = event
+ self._callback(eventDict)
+
+ def setCallback(self, callbackFunction=None):
+ """Attach a listener to the backend.
+
+ Limitation: Only one listener at a time.
+
+ :param callbackFunction: function accepting a dictionary as input
+ to handle the graph events
+ If None (default), use a default listener.
+ """
+ # TODO allow multiple listeners, keep a weakref on it
+ # allow register listener by event type
+ if callbackFunction is None:
+ callbackFunction = self.graphCallback
+ self._callback = callbackFunction
+
+ def graphCallback(self, ddict=None):
+ """This callback is going to receive all the events from the plot.
+
+ Those events will consist on a dictionary and among the dictionary
+ keys the key 'event' is mandatory to describe the type of event.
+ This default implementation only handles setting the active curve.
+ """
+
+ if ddict is None:
+ ddict = {}
+ _logger.debug("Received dict keys = %s", str(ddict.keys()))
+ _logger.debug(str(ddict))
+ if ddict['event'] in ["legendClicked", "curveClicked"]:
+ if ddict['button'] == "left":
+ self.setActiveCurve(ddict['label'])
+
+ def saveGraph(self, filename, fileFormat=None, dpi=None, **kw):
+ """Save a snapshot of the plot.
+
+ Supported file formats: "png", "svg", "pdf", "ps", "eps",
+ "tif", "tiff", "jpeg", "jpg".
+
+ :param filename: Destination
+ :type filename: str, StringIO or BytesIO
+ :param str fileFormat: String specifying the format
+ :return: False if cannot save the plot, True otherwise
+ """
+ if kw:
+ _logger.warning('Extra parameters ignored: %s', str(kw))
+
+ if fileFormat is None:
+ if not hasattr(filename, 'lower'):
+ _logger.warning(
+ 'saveGraph cancelled, cannot define file format.')
+ return False
+ else:
+ fileFormat = (filename.split(".")[-1]).lower()
+
+ supportedFormats = ("png", "svg", "pdf", "ps", "eps",
+ "tif", "tiff", "jpeg", "jpg")
+
+ if fileFormat not in supportedFormats:
+ _logger.warning('Unsupported format %s', fileFormat)
+ return False
+ else:
+ self._backend.saveGraph(filename,
+ fileFormat=fileFormat,
+ dpi=dpi)
+ return True
+
+ def getDataMargins(self):
+ """Get the default data margin ratios, see :meth:`setDataMargins`.
+
+ :return: The margin ratios for each side (xMin, xMax, yMin, yMax).
+ :rtype: A 4-tuple of floats.
+ """
+ return self._defaultDataMargins
+
+ def setDataMargins(self, xMinMargin=0., xMaxMargin=0.,
+ yMinMargin=0., yMaxMargin=0.):
+ """Set the default data margins to use in :meth:`resetZoom`.
+
+ Set the default ratios of margins (as floats) to add around the data
+ inside the plot area for each side.
+ """
+ self._defaultDataMargins = (xMinMargin, xMaxMargin,
+ yMinMargin, yMaxMargin)
+
+ def getAutoReplot(self):
+ """Return True if replot is automatically handled, False otherwise.
+
+ See :meth`setAutoReplot`.
+ """
+ return self._autoreplot
+
+ def setAutoReplot(self, autoreplot=True):
+ """Set automatic replot mode.
+
+ When enabled, the plot is redrawn automatically when changed.
+ When disabled, the plot is not redrawn when its content change.
+ Instead, it :meth:`replot` must be called.
+
+ :param bool autoreplot: True to enable it (default),
+ False to disable it.
+ """
+ self._autoreplot = bool(autoreplot)
+
+ # If the plot is dirty before enabling autoreplot,
+ # then _backend.postRedisplay will never be called from _setDirtyPlot
+ if self._autoreplot and self._getDirtyPlot():
+ self._backend.postRedisplay()
+
+ def replot(self):
+ """Redraw the plot immediately."""
+ for item in self._contentToUpdate:
+ item._update(self._backend)
+ self._contentToUpdate.clear()
+ self._backend.replot()
+ self._dirty = False # reset dirty flag
+
+ def resetZoom(self, dataMargins=None):
+ """Reset the plot limits to the bounds of the data and redraw the plot.
+
+ It automatically scale limits of axes that are in autoscale mode
+ (See :meth:`setXAxisAutoScale`, :meth:`setYAxisAutoScale`).
+ It keeps current limits on axes that are not in autoscale mode.
+
+ Extra margins can be added around the data inside the plot area.
+ Margins are given as one ratio of the data range per limit of the
+ data (xMin, xMax, yMin and yMax limits).
+ For log scale, extra margins are applied in log10 of the data.
+
+ :param dataMargins: Ratios of margins to add around the data inside
+ the plot area for each side (Default: no margins).
+ :type dataMargins: A 4-tuple of float as (xMin, xMax, yMin, yMax).
+ """
+ if dataMargins is None:
+ dataMargins = self._defaultDataMargins
+
+ xLimits = self.getGraphXLimits()
+ yLimits = self.getGraphYLimits(axis='left')
+ y2Limits = self.getGraphYLimits(axis='right')
+
+ xAuto = self.isXAxisAutoScale()
+ yAuto = self.isYAxisAutoScale()
+
+ if not xAuto and not yAuto:
+ _logger.debug("Nothing to autoscale")
+ else: # Some axes to autoscale
+
+ # Get data range
+ ranges = self.getDataRange()
+ xmin, xmax = (1., 100.) if ranges.x is None else ranges.x
+ ymin, ymax = (1., 100.) if ranges.y is None else ranges.y
+ if ranges.yright is None:
+ ymin2, ymax2 = None, None
+ else:
+ ymin2, ymax2 = ranges.yright
+
+ # Add margins around data inside the plot area
+ newLimits = list(_utils.addMarginsToLimits(
+ dataMargins,
+ self.isXAxisLogarithmic(),
+ self.isYAxisLogarithmic(),
+ xmin, xmax, ymin, ymax, ymin2, ymax2))
+
+ if self.isKeepDataAspectRatio():
+ # Use limits with margins to keep ratio
+ xmin, xmax, ymin, ymax = newLimits[:4]
+
+ # Compute bbox wth figure aspect ratio
+ plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+ plotRatio = plotHeight / plotWidth
+
+ if plotRatio > 0.:
+ dataRatio = (ymax - ymin) / (xmax - xmin)
+ if dataRatio < plotRatio:
+ # Increase y range
+ ycenter = 0.5 * (ymax + ymin)
+ yrange = (xmax - xmin) * plotRatio
+ newLimits[2] = ycenter - 0.5 * yrange
+ newLimits[3] = ycenter + 0.5 * yrange
+
+ elif dataRatio > plotRatio:
+ # Increase x range
+ xcenter = 0.5 * (xmax + xmin)
+ xrange_ = (ymax - ymin) / plotRatio
+ newLimits[0] = xcenter - 0.5 * xrange_
+ newLimits[1] = xcenter + 0.5 * xrange_
+
+ self.setLimits(*newLimits)
+
+ if not xAuto and yAuto:
+ self.setGraphXLimits(*xLimits)
+ elif xAuto and not yAuto:
+ if y2Limits is not None:
+ self.setGraphYLimits(
+ y2Limits[0], y2Limits[1], axis='right')
+ if yLimits is not None:
+ self.setGraphYLimits(yLimits[0], yLimits[1], axis='left')
+
+ self._setDirtyPlot()
+
+ if (xLimits != self.getGraphXLimits() or
+ yLimits != self.getGraphYLimits(axis='left') or
+ y2Limits != self.getGraphYLimits(axis='right')):
+ self._notifyLimitsChanged()
+
+ # Coord conversion
+
+ def dataToPixel(self, x=None, y=None, axis="left", check=True):
+ """Convert a position in data coordinates to a position in pixels.
+
+ :param float x: The X coordinate in data space. If None (default)
+ the middle position of the displayed data is used.
+ :param float y: The Y coordinate in data space. If None (default)
+ the middle position of the displayed data is used.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :param bool check: True to return None if outside displayed area,
+ False to convert to pixels anyway
+ :returns: The corresponding position in pixels or
+ None if the data position is not in the displayed area and
+ check is True.
+ :rtype: A tuple of 2 floats: (xPixel, yPixel) or None.
+ """
+ assert axis in ("left", "right")
+
+ xmin, xmax = self.getGraphXLimits()
+ ymin, ymax = self.getGraphYLimits(axis=axis)
+
+ if x is None:
+ x = 0.5 * (xmax + xmin)
+ if y is None:
+ y = 0.5 * (ymax + ymin)
+
+ if check:
+ if x > xmax or x < xmin:
+ return None
+
+ if y > ymax or y < ymin:
+ return None
+
+ return self._backend.dataToPixel(x, y, axis=axis)
+
+ def pixelToData(self, x, y, axis="left", check=False):
+ """Convert a position in pixels to a position in data coordinates.
+
+ :param float x: The X coordinate in pixels. If None (default)
+ the center of the widget is used.
+ :param float y: The Y coordinate in pixels. If None (default)
+ the center of the widget is used.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :param bool check: Toggle checking if pixel is in plot area.
+ If False, this method never returns None.
+ :returns: The corresponding position in data space or
+ None if the pixel position is not in the plot area.
+ :rtype: A tuple of 2 floats: (xData, yData) or None.
+ """
+ assert axis in ("left", "right")
+ return self._backend.pixelToData(x, y, axis=axis, check=check)
+
+ def getPlotBoundsInPixels(self):
+ """Plot area bounds in widget coordinates in pixels.
+
+ :return: bounds as a 4-tuple of int: (left, top, width, height)
+ """
+ return self._backend.getPlotBoundsInPixels()
+
+ # Interaction support
+
+ def setGraphCursorShape(self, cursor=None):
+ """Set the cursor shape.
+
+ :param str cursor: Name of the cursor shape
+ """
+ self._backend.setGraphCursorShape(cursor)
+
+ def _pickMarker(self, x, y, test=None):
+ """Pick a marker at the given position.
+
+ To use for interaction implementation.
+
+ :param float x: X position in pixels.
+ :param float y: Y position in pixels.
+ :param test: A callable to call for each picked marker to filter
+ picked markers. If None (default), do not filter markers.
+ """
+ if test is None:
+ def test(mark):
+ return True
+
+ markers = self._backend.pickItems(x, y)
+ legends = [m['legend'] for m in markers if m['kind'] == 'marker']
+
+ for legend in reversed(legends):
+ marker = self._getMarker(legend)
+ if marker is not None and test(marker):
+ return marker
+ return None
+
+ def _getAllMarkers(self, just_legend=False):
+ """Returns all markers' legend or objects
+
+ :param bool just_legend: True to get the legend of the markers,
+ False (the default) to get marker objects.
+ :return: list of legend of list of marker objects
+ :rtype: list of str or list of marker objects
+ """
+ return self._getItems(
+ kind='marker', just_legend=just_legend, withhidden=True)
+
+ def _getMarker(self, legend=None):
+ """Get the object describing a specific marker.
+
+ It returns None in case no matching marker is found
+
+ :param str legend: The legend of the marker to retrieve
+ :rtype: None of marker object
+ """
+ return self._getItem(kind='marker', legend=legend)
+
+ def _pickImageOrCurve(self, x, y, test=None):
+ """Pick an image or a curve at the given position.
+
+ To use for interaction implementation.
+
+ :param float x: X position in pixelsparam float y: Y position in pixels
+ :param test: A callable to call for each picked item to filter
+ picked items. If None (default), do not filter items.
+ """
+ if test is None:
+ def test(i):
+ return True
+
+ allItems = self._backend.pickItems(x, y)
+ allItems = [item for item in allItems
+ if item['kind'] in ['curve', 'image']]
+
+ for item in reversed(allItems):
+ kind, legend = item['kind'], item['legend']
+ if kind == 'curve':
+ curve = self.getCurve(legend)
+ if curve is not None and test(curve):
+ return kind, curve, item['xdata'], item['ydata']
+
+ elif kind == 'image':
+ image = self.getImage(legend)
+ if image is not None and test(image):
+ return kind, image, None
+
+ else:
+ _logger.warning('Unsupported kind: %s', kind)
+
+ return None
+
+ # User event handling #
+
+ def _isPositionInPlotArea(self, x, y):
+ """Project position in pixel to the closest point in the plot area
+
+ :param float x: X coordinate in widget coordinate (in pixel)
+ :param float y: Y coordinate in widget coordinate (in pixel)
+ :return: (x, y) in widget coord (in pixel) in the plot area
+ """
+ left, top, width, height = self.getPlotBoundsInPixels()
+ xPlot = numpy.clip(x, left, left + width)
+ yPlot = numpy.clip(y, top, top + height)
+ return xPlot, yPlot
+
+ def onMousePress(self, xPixel, yPixel, btn):
+ """Handle mouse press event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ :param str btn: Mouse button in 'left', 'middle', 'right'
+ """
+ if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel):
+ self._pressedButtons.append(btn)
+ self._eventHandler.handleEvent('press', xPixel, yPixel, btn)
+
+ def onMouseMove(self, xPixel, yPixel):
+ """Handle mouse move event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ """
+ inXPixel, inYPixel = self._isPositionInPlotArea(xPixel, yPixel)
+ isCursorInPlot = inXPixel == xPixel and inYPixel == yPixel
+
+ if self._cursorInPlot != isCursorInPlot:
+ self._cursorInPlot = isCursorInPlot
+ self._eventHandler.handleEvent(
+ 'enter' if self._cursorInPlot else 'leave')
+
+ if isCursorInPlot:
+ # Signal mouse move event
+ dataPos = self.pixelToData(inXPixel, inYPixel)
+ assert dataPos is not None
+
+ btn = self._pressedButtons[-1] if self._pressedButtons else None
+ event = PlotEvents.prepareMouseSignal(
+ 'mouseMoved', btn, dataPos[0], dataPos[1], xPixel, yPixel)
+ self.notify(**event)
+
+ # Either button was pressed in the plot or cursor is in the plot
+ if isCursorInPlot or self._pressedButtons:
+ self._eventHandler.handleEvent('move', inXPixel, inYPixel)
+
+ def onMouseRelease(self, xPixel, yPixel, btn):
+ """Handle mouse release event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ :param str btn: Mouse button in 'left', 'middle', 'right'
+ """
+ try:
+ self._pressedButtons.remove(btn)
+ except ValueError:
+ pass
+ else:
+ xPixel, yPixel = self._isPositionInPlotArea(xPixel, yPixel)
+ self._eventHandler.handleEvent('release', xPixel, yPixel, btn)
+
+ def onMouseWheel(self, xPixel, yPixel, angleInDegrees):
+ """Handle mouse wheel event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ :param float angleInDegrees: Angle corresponding to wheel motion.
+ Positive for movement away from the user,
+ negative for movement toward the user.
+ """
+ if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel):
+ self._eventHandler.handleEvent(
+ 'wheel', xPixel, yPixel, angleInDegrees)
+
+ def onMouseLeaveWidget(self):
+ """Handle mouse leave widget event."""
+ if self._cursorInPlot:
+ self._cursorInPlot = False
+ self._eventHandler.handleEvent('leave')
+
+ # Interaction modes #
+
+ def getInteractiveMode(self):
+ """Returns the current interactive mode as a dict.
+
+ The returned dict contains at least the key 'mode'.
+ Mode can be: 'draw', 'pan', 'select', 'zoom'.
+ It can also contains extra keys (e.g., 'color') specific to a mode
+ as provided to :meth:`setInteractiveMode`.
+ """
+ return self._eventHandler.getInteractiveMode()
+
+ def setInteractiveMode(self, mode, color='black',
+ shape='polygon', label=None,
+ zoomOnWheel=True, source=None, width=None):
+ """Switch the interactive mode.
+
+ :param str mode: The name of the interactive mode.
+ In 'draw', 'pan', 'select', 'zoom'.
+ :param color: Only for 'draw' and 'zoom' modes.
+ Color to use for drawing selection area. Default black.
+ :type color: Color description: The name as a str or
+ a tuple of 4 floats.
+ :param str shape: Only for 'draw' mode. The kind of shape to draw.
+ In 'polygon', 'rectangle', 'line', 'vline', 'hline',
+ 'freeline'.
+ Default is 'polygon'.
+ :param str label: Only for 'draw' mode, sent in drawing events.
+ :param bool zoomOnWheel: Toggle zoom on wheel support
+ :param source: A user-defined object (typically the caller object)
+ that will be send in the interactiveModeChanged event,
+ to identify which object required a mode change.
+ Default: None
+ :param float width: Width of the pencil. Only for draw pencil mode.
+ """
+ self._eventHandler.setInteractiveMode(mode, color, shape, label, width)
+ self._eventHandler.zoomOnWheel = zoomOnWheel
+
+ self.notify(
+ 'interactiveModeChanged', source=source)
+
+ # Deprecated #
+
+ def isDrawModeEnabled(self):
+ """Deprecated, use :meth:`getInteractiveMode` instead.
+
+ Return True if the current interactive state is drawing."""
+ _logger.warning(
+ 'isDrawModeEnabled deprecated, use getInteractiveMode instead')
+ return self.getInteractiveMode()['mode'] == 'draw'
+
+ def setDrawModeEnabled(self, flag=True, shape='polygon', label=None,
+ color=None, **kwargs):
+ """Deprecated, use :meth:`setInteractiveMode` instead.
+
+ Set the drawing mode if flag is True and its parameters.
+
+ If flag is False, only item selection is enabled.
+
+ Warning: Zoom and drawing are not compatible and cannot be enabled
+ simultaneously.
+
+ :param bool flag: True to enable drawing and disable zoom and select.
+ :param str shape: Type of item to be drawn in:
+ hline, vline, rectangle, polygon (default)
+ :param str label: Associated text for identifying draw signals
+ :param color: The color to use to draw the selection area
+ :type color: string ("#RRGGBB") or 4 column unsigned byte array or
+ one of the predefined color names defined in Colors.py
+ """
+ _logger.warning(
+ 'setDrawModeEnabled deprecated, use setInteractiveMode instead')
+
+ if kwargs:
+ _logger.warning('setDrawModeEnabled ignores additional parameters')
+
+ if color is None:
+ color = 'black'
+
+ if flag:
+ self.setInteractiveMode('draw', shape=shape,
+ label=label, color=color)
+ elif self.getInteractiveMode()['mode'] == 'draw':
+ self.setInteractiveMode('select')
+
+ def getDrawMode(self):
+ """Deprecated, use :meth:`getInteractiveMode` instead.
+
+ Return the draw mode parameters as a dict of None.
+
+ It returns None if the interactive mode is not a drawing mode,
+ otherwise, it returns a dict containing the drawing mode parameters
+ as provided to :meth:`setDrawModeEnabled`.
+ """
+ _logger.warning(
+ 'getDrawMode deprecated, use getInteractiveMode instead')
+ mode = self.getInteractiveMode()
+ return mode if mode['mode'] == 'draw' else None
+
+ def isZoomModeEnabled(self):
+ """Deprecated, use :meth:`getInteractiveMode` instead.
+
+ Return True if the current interactive state is zooming."""
+ _logger.warning(
+ 'isZoomModeEnabled deprecated, use getInteractiveMode instead')
+ return self.getInteractiveMode()['mode'] == 'zoom'
+
+ def setZoomModeEnabled(self, flag=True, color=None):
+ """Deprecated, use :meth:`setInteractiveMode` instead.
+
+ Set the zoom mode if flag is True, else item selection is enabled.
+
+ Warning: Zoom and drawing are not compatible and cannot be enabled
+ simultaneously
+
+ :param bool flag: If True, enable zoom and select mode.
+ :param color: The color to use to draw the selection area.
+ (Default: 'black')
+ :param color: The color to use to draw the selection area
+ :type color: string ("#RRGGBB") or 4 column unsigned byte array or
+ one of the predefined color names defined in Colors.py
+ """
+ _logger.warning(
+ 'setZoomModeEnabled deprecated, use setInteractiveMode instead')
+ if color is None:
+ color = 'black'
+
+ if flag:
+ self.setInteractiveMode('zoom', color=color)
+ elif self.getInteractiveMode()['mode'] == 'zoom':
+ self.setInteractiveMode('select')
+
+ def insertMarker(self, *args, **kwargs):
+ """Deprecated, use :meth:`addMarker` instead."""
+ _logger.warning(
+ 'insertMarker deprecated, use addMarker instead.')
+ return self.addMarker(*args, **kwargs)
+
+ def insertXMarker(self, *args, **kwargs):
+ """Deprecated, use :meth:`addXMarker` instead."""
+ _logger.warning(
+ 'insertXMarker deprecated, use addXMarker instead.')
+ return self.addXMarker(*args, **kwargs)
+
+ def insertYMarker(self, *args, **kwargs):
+ """Deprecated, use :meth:`addYMarker` instead."""
+ _logger.warning(
+ 'insertYMarker deprecated, use addYMarker instead.')
+ return self.addYMarker(*args, **kwargs)
+
+ def isActiveCurveHandlingEnabled(self):
+ """Deprecated, use :meth:`isActiveCurveHandling` instead."""
+ _logger.warning(
+ 'isActiveCurveHandlingEnabled deprecated, '
+ 'use isActiveCurveHandling instead.')
+ return self.isActiveCurveHandling()
+
+ def enableActiveCurveHandling(self, *args, **kwargs):
+ """Deprecated, use :meth:`setActiveCurveHandling` instead."""
+ _logger.warning(
+ 'enableActiveCurveHandling deprecated, '
+ 'use setActiveCurveHandling instead.')
+ return self.setActiveCurveHandling(*args, **kwargs)
+
+ def invertYAxis(self, *args, **kwargs):
+ """Deprecated, use :meth:`setYAxisInverted` instead."""
+ _logger.warning('invertYAxis deprecated, '
+ 'use setYAxisInverted instead.')
+ return self.setYAxisInverted(*args, **kwargs)
+
+ def showGrid(self, flag=True):
+ """Deprecated, use :meth:`setGraphGrid` instead."""
+ _logger.warning("showGrid deprecated, use setGraphGrid instead")
+ if flag in (0, False):
+ flag = None
+ elif flag in (1, True):
+ flag = 'major'
+ else:
+ flag = 'both'
+ return self.setGraphGrid(flag)
+
+ def keepDataAspectRatio(self, *args, **kwargs):
+ """Deprecated, use :meth:`setKeepDataAspectRatio`."""
+ _logger.warning('keepDataAspectRatio deprecated,'
+ 'use setKeepDataAspectRatio instead')
+ return self.setKeepDataAspectRatio(*args, **kwargs)
diff --git a/silx/gui/plot/PlotActions.py b/silx/gui/plot/PlotActions.py
new file mode 100644
index 0000000..aad27d2
--- /dev/null
+++ b/silx/gui/plot/PlotActions.py
@@ -0,0 +1,1386 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a set of QAction to use with :class:`.PlotWidget`.
+
+The following QAction are available:
+
+- :class:`ColormapAction`
+- :class:`CopyAction`
+- :class:`CrosshairAction`
+- :class:`CurveStyleAction`
+- :class:`FitAction`
+- :class:`GridAction`
+- :class:`KeepAspectRatioAction`
+- :class:`PanWithArrowKeysAction`
+- :class:`PrintAction`
+- :class:`ResetZoomAction`
+- :class:`SaveAction`
+- :class:`XAxisLogarithmicAction`
+- :class:`XAxisAutoScaleAction`
+- :class:`YAxisInvertedAction`
+- :class:`YAxisLogarithmicAction`
+- :class:`YAxisAutoScaleAction`
+- :class:`ZoomInAction`
+- :class:`ZoomOutAction`
+"""
+
+from __future__ import division
+
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "20/04/2017"
+
+
+from collections import OrderedDict
+import logging
+import sys
+import traceback
+import weakref
+
+if sys.version_info[0] == 3:
+ from io import BytesIO
+else:
+ import cStringIO as _StringIO
+ BytesIO = _StringIO.StringIO
+
+import numpy
+
+from .. import icons
+from .. import qt
+from .._utils import convertArrayToQImage
+from . import Colors, items
+from .ColormapDialog import ColormapDialog
+from ._utils import applyZoomToPlot as _applyZoomToPlot
+from silx.third_party.EdfFile import EdfFile
+from silx.third_party.TiffIO import TiffIO
+from silx.math.histogram import Histogramnd
+from silx.math.medianfilter import medfilt2d
+from silx.gui.widgets.MedianFilterDialog import MedianFilterDialog
+
+from silx.io.utils import save1D, savespec
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PlotAction(qt.QAction):
+ """Base class for QAction that operates on a PlotWidget.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ :param icon: QIcon or str name of icon to use
+ :param str text: The name of this action to be used for menu label
+ :param str tooltip: The text of the tooltip
+ :param triggered: The callback to connect to the action's triggered
+ signal or None for no callback.
+ :param bool checkable: True for checkable action, False otherwise (default)
+ :param parent: See :class:`QAction`.
+ """
+
+ def __init__(self, plot, icon, text, tooltip=None,
+ triggered=None, checkable=False, parent=None):
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+
+ if not isinstance(icon, qt.QIcon):
+ # Try with icon as a string and load corresponding icon
+ icon = icons.getQIcon(icon)
+
+ super(PlotAction, self).__init__(icon, text, parent)
+
+ if tooltip is not None:
+ self.setToolTip(tooltip)
+
+ self.setCheckable(checkable)
+
+ if triggered is not None:
+ self.triggered[bool].connect(triggered)
+
+ @property
+ def plot(self):
+ """The :class:`.PlotWidget` this action group is controlling."""
+ return self._plotRef()
+
+
+class ResetZoomAction(PlotAction):
+ """QAction controlling reset zoom on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ResetZoomAction, self).__init__(
+ plot, icon='zoom-original', text='Reset Zoom',
+ tooltip='Auto-scale the graph',
+ triggered=self._actionTriggered,
+ checkable=False, parent=parent)
+ self._autoscaleChanged(True)
+ plot.sigSetXAxisAutoScale.connect(self._autoscaleChanged)
+ plot.sigSetYAxisAutoScale.connect(self._autoscaleChanged)
+
+ def _autoscaleChanged(self, enabled):
+ self.setEnabled(
+ self.plot.isXAxisAutoScale() or self.plot.isYAxisAutoScale())
+
+ if self.plot.isXAxisAutoScale() and self.plot.isYAxisAutoScale():
+ tooltip = 'Auto-scale the graph'
+ elif self.plot.isXAxisAutoScale(): # And not Y axis
+ tooltip = 'Auto-scale the x-axis of the graph only'
+ elif self.plot.isYAxisAutoScale(): # And not X axis
+ tooltip = 'Auto-scale the y-axis of the graph only'
+ else: # no axis in autoscale
+ tooltip = 'Auto-scale the graph'
+ self.setToolTip(tooltip)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.resetZoom()
+
+
+class ZoomInAction(PlotAction):
+ """QAction performing a zoom-in on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ZoomInAction, self).__init__(
+ plot, icon='zoom-in', text='Zoom In',
+ tooltip='Zoom in the plot',
+ triggered=self._actionTriggered,
+ checkable=False, parent=parent)
+ self.setShortcut(qt.QKeySequence.ZoomIn)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def _actionTriggered(self, checked=False):
+ _applyZoomToPlot(self.plot, 1.1)
+
+
+class ZoomOutAction(PlotAction):
+ """QAction performing a zoom-out on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ZoomOutAction, self).__init__(
+ plot, icon='zoom-out', text='Zoom Out',
+ tooltip='Zoom out the plot',
+ triggered=self._actionTriggered,
+ checkable=False, parent=parent)
+ self.setShortcut(qt.QKeySequence.ZoomOut)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def _actionTriggered(self, checked=False):
+ _applyZoomToPlot(self.plot, 1. / 1.1)
+
+
+class XAxisAutoScaleAction(PlotAction):
+ """QAction controlling X axis autoscale on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(XAxisAutoScaleAction, self).__init__(
+ plot, icon='plot-xauto', text='X Autoscale',
+ tooltip='Enable x-axis auto-scale when checked.\n'
+ 'If unchecked, x-axis does not change when reseting zoom.',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.setChecked(plot.isXAxisAutoScale())
+ plot.sigSetXAxisAutoScale.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setXAxisAutoScale(checked)
+ if checked:
+ self.plot.resetZoom()
+
+
+class YAxisAutoScaleAction(PlotAction):
+ """QAction controlling Y axis autoscale on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(YAxisAutoScaleAction, self).__init__(
+ plot, icon='plot-yauto', text='Y Autoscale',
+ tooltip='Enable y-axis auto-scale when checked.\n'
+ 'If unchecked, y-axis does not change when reseting zoom.',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.setChecked(plot.isXAxisAutoScale())
+ plot.sigSetYAxisAutoScale.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setYAxisAutoScale(checked)
+ if checked:
+ self.plot.resetZoom()
+
+
+class XAxisLogarithmicAction(PlotAction):
+ """QAction controlling X axis log scale on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(XAxisLogarithmicAction, self).__init__(
+ plot, icon='plot-xlog', text='X Log. scale',
+ tooltip='Logarithmic x-axis when checked',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.setChecked(plot.isXAxisLogarithmic())
+ plot.sigSetXAxisLogarithmic.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setXAxisLogarithmic(checked)
+
+
+class YAxisLogarithmicAction(PlotAction):
+ """QAction controlling Y axis log scale on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(YAxisLogarithmicAction, self).__init__(
+ plot, icon='plot-ylog', text='Y Log. scale',
+ tooltip='Logarithmic y-axis when checked',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.setChecked(plot.isYAxisLogarithmic())
+ plot.sigSetYAxisLogarithmic.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setYAxisLogarithmic(checked)
+
+
+class GridAction(PlotAction):
+ """QAction controlling grid mode on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param str gridMode: The grid mode to use in 'both', 'major'.
+ See :meth:`.PlotWidget.setGraphGrid`
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, gridMode='both', parent=None):
+ assert gridMode in ('both', 'major')
+ self._gridMode = gridMode
+
+ super(GridAction, self).__init__(
+ plot, icon='plot-grid', text='Grid',
+ tooltip='Toggle grid (on/off)',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.setChecked(plot.getGraphGrid() is not None)
+ plot.sigSetGraphGrid.connect(self._gridChanged)
+
+ def _gridChanged(self, which):
+ """Slot listening for PlotWidget grid mode change."""
+ self.setChecked(which != 'None')
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setGraphGrid(self._gridMode if checked else None)
+
+
+class CurveStyleAction(PlotAction):
+ """QAction controlling curve style on a :class:`.PlotWidget`.
+
+ It changes the default line and markers style which updates all
+ curves on the plot.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(CurveStyleAction, self).__init__(
+ plot, icon='plot-toggle-points', text='Curve style',
+ tooltip='Change curve line and markers style',
+ triggered=self._actionTriggered,
+ checkable=False, parent=parent)
+
+ def _actionTriggered(self, checked=False):
+ currentState = (self.plot.isDefaultPlotLines(),
+ self.plot.isDefaultPlotPoints())
+
+ # line only, line and symbol, symbol only
+ states = (True, False), (True, True), (False, True)
+ newState = states[(states.index(currentState) + 1) % 3]
+
+ self.plot.setDefaultPlotLines(newState[0])
+ self.plot.setDefaultPlotPoints(newState[1])
+
+
+class ColormapAction(PlotAction):
+ """QAction opening a ColormapDialog to update the colormap.
+
+ Both the active image colormap and the default colormap are updated.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+ def __init__(self, plot, parent=None):
+ self._dialog = None # To store an instance of ColormapDialog
+ super(ColormapAction, self).__init__(
+ plot, icon='colormap', text='Colormap',
+ tooltip="Change colormap",
+ triggered=self._actionTriggered,
+ checkable=False, parent=parent)
+
+ def _actionTriggered(self, checked=False):
+ """Create a cmap dialog and update active image and default cmap."""
+ # Create the dialog if not already existing
+ if self._dialog is None:
+ self._dialog = ColormapDialog()
+
+ image = self.plot.getActiveImage()
+ if not isinstance(image, items.ColormapMixIn):
+ # No active image or active image is RGBA,
+ # set dialog from default info
+ colormap = self.plot.getDefaultColormap()
+
+ self._dialog.setHistogram() # Reset histogram and range if any
+
+ else:
+ # Set dialog from active image
+ colormap = image.getColormap()
+
+ data = image.getData(copy=False)
+
+ goodData = data[numpy.isfinite(data)]
+ if goodData.size > 0:
+ dataMin = goodData.min()
+ dataMax = goodData.max()
+ else:
+ qt.QMessageBox.warning(
+ self, "No Data",
+ "Image data does not contain any real value")
+ dataMin, dataMax = 1., 10.
+
+ self._dialog.setHistogram() # Reset histogram if any
+ self._dialog.setDataRange(dataMin, dataMax)
+ # The histogram should be done in a worker thread
+ # hist, bin_edges = numpy.histogram(goodData, bins=256)
+ # self._dialog.setHistogram(hist, bin_edges)
+
+ self._dialog.setColormap(**colormap)
+
+ # Run the dialog listening to colormap change
+ self._dialog.sigColormapChanged.connect(self._colormapChanged)
+ result = self._dialog.exec_()
+ self._dialog.sigColormapChanged.disconnect(self._colormapChanged)
+
+ if not result: # Restore the previous colormap
+ self._colormapChanged(colormap)
+
+ def _colormapChanged(self, colormap):
+ # Update default colormap
+ self.plot.setDefaultColormap(colormap)
+
+ # Update active image colormap
+ activeImage = self.plot.getActiveImage()
+ if isinstance(activeImage, items.ColormapMixIn):
+ activeImage.setColormap(colormap)
+
+
+class KeepAspectRatioAction(PlotAction):
+ """QAction controlling aspect ratio on a :class:`.PlotWidget`.
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ # Uses two images for checked/unchecked states
+ self._states = {
+ False: (icons.getQIcon('shape-circle-solid'),
+ "Keep data aspect ratio"),
+ True: (icons.getQIcon('shape-ellipse-solid'),
+ "Do no keep data aspect ratio")
+ }
+
+ icon, tooltip = self._states[plot.isKeepDataAspectRatio()]
+ super(KeepAspectRatioAction, self).__init__(
+ plot,
+ icon=icon,
+ text='Toggle keep aspect ratio',
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=False,
+ parent=parent)
+ plot.sigSetKeepDataAspectRatio.connect(
+ self._keepDataAspectRatioChanged)
+
+ def _keepDataAspectRatioChanged(self, aspectRatio):
+ """Handle Plot set keep aspect ratio signal"""
+ icon, tooltip = self._states[aspectRatio]
+ self.setIcon(icon)
+ self.setToolTip(tooltip)
+
+ def _actionTriggered(self, checked=False):
+ # This will trigger _keepDataAspectRatioChanged
+ self.plot.setKeepDataAspectRatio(not self.plot.isKeepDataAspectRatio())
+
+
+class YAxisInvertedAction(PlotAction):
+ """QAction controlling Y orientation on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ # Uses two images for checked/unchecked states
+ self._states = {
+ False: (icons.getQIcon('plot-ydown'),
+ "Orient Y axis downward"),
+ True: (icons.getQIcon('plot-yup'),
+ "Orient Y axis upward"),
+ }
+
+ icon, tooltip = self._states[plot.isYAxisInverted()]
+ super(YAxisInvertedAction, self).__init__(
+ plot,
+ icon=icon,
+ text='Invert Y Axis',
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=False,
+ parent=parent)
+ plot.sigSetYAxisInverted.connect(self._yAxisInvertedChanged)
+
+ def _yAxisInvertedChanged(self, inverted):
+ """Handle Plot set y axis inverted signal"""
+ icon, tooltip = self._states[inverted]
+ self.setIcon(icon)
+ self.setToolTip(tooltip)
+
+ def _actionTriggered(self, checked=False):
+ # This will trigger _yAxisInvertedChanged
+ self.plot.setYAxisInverted(not self.plot.isYAxisInverted())
+
+
+class SaveAction(PlotAction):
+ """QAction for saving Plot content.
+
+ It opens a Save as... dialog.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ :param parent: See :class:`QAction`.
+ """
+ # TODO find a way to make the filter list selectable and extensible
+
+ SNAPSHOT_FILTER_SVG = 'Plot Snapshot as SVG (*.svg)'
+
+ SNAPSHOT_FILTERS = ('Plot Snapshot as PNG (*.png)',
+ 'Plot Snapshot as JPEG (*.jpg)',
+ SNAPSHOT_FILTER_SVG)
+
+ # Dict of curve filters with CSV-like format
+ # Using ordered dict to guarantee filters order
+ # Note: '%.18e' is numpy.savetxt default format
+ CURVE_FILTERS_TXT = OrderedDict((
+ ('Curve as Raw ASCII (*.txt)',
+ {'fmt': '%.18e', 'delimiter': ' ', 'header': False}),
+ ('Curve as ";"-separated CSV (*.csv)',
+ {'fmt': '%.18e', 'delimiter': ';', 'header': True}),
+ ('Curve as ","-separated CSV (*.csv)',
+ {'fmt': '%.18e', 'delimiter': ',', 'header': True}),
+ ('Curve as tab-separated CSV (*.csv)',
+ {'fmt': '%.18e', 'delimiter': '\t', 'header': True}),
+ ('Curve as OMNIC CSV (*.csv)',
+ {'fmt': '%.7E', 'delimiter': ',', 'header': False}),
+ ('Curve as SpecFile (*.dat)',
+ {'fmt': '%.7g', 'delimiter': '', 'header': False})
+ ))
+
+ CURVE_FILTER_NPY = 'Curve as NumPy binary file (*.npy)'
+
+ CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [CURVE_FILTER_NPY]
+
+ ALL_CURVES_FILTERS = ("All curves as SpecFile (*.dat)", )
+
+ IMAGE_FILTER_EDF = 'Image data as EDF (*.edf)'
+ IMAGE_FILTER_TIFF = 'Image data as TIFF (*.tif)'
+ IMAGE_FILTER_NUMPY = 'Image data as NumPy binary file (*.npy)'
+ IMAGE_FILTER_ASCII = 'Image data as ASCII (*.dat)'
+ IMAGE_FILTER_CSV_COMMA = 'Image data as ,-separated CSV (*.csv)'
+ IMAGE_FILTER_CSV_SEMICOLON = 'Image data as ;-separated CSV (*.csv)'
+ IMAGE_FILTER_CSV_TAB = 'Image data as tab-separated CSV (*.csv)'
+ IMAGE_FILTER_RGB_PNG = 'Image as PNG (*.png)'
+ IMAGE_FILTER_RGB_TIFF = 'Image as TIFF (*.tif)'
+ IMAGE_FILTERS = (IMAGE_FILTER_EDF,
+ IMAGE_FILTER_TIFF,
+ IMAGE_FILTER_NUMPY,
+ IMAGE_FILTER_ASCII,
+ IMAGE_FILTER_CSV_COMMA,
+ IMAGE_FILTER_CSV_SEMICOLON,
+ IMAGE_FILTER_CSV_TAB,
+ IMAGE_FILTER_RGB_PNG,
+ IMAGE_FILTER_RGB_TIFF)
+
+ def __init__(self, plot, parent=None):
+ super(SaveAction, self).__init__(
+ plot, icon='document-save', text='Save as...',
+ tooltip='Save curve/image/plot snapshot dialog',
+ triggered=self._actionTriggered,
+ checkable=False, parent=parent)
+ self.setShortcut(qt.QKeySequence.Save)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def _errorMessage(self, informativeText=''):
+ """Display an error message."""
+ # TODO issue with QMessageBox size fixed and too small
+ msg = qt.QMessageBox(self.plot)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setInformativeText(informativeText + ' ' + str(sys.exc_info()[1]))
+ msg.setDetailedText(traceback.format_exc())
+ msg.exec_()
+
+ def _saveSnapshot(self, filename, nameFilter):
+ """Save a snapshot of the :class:`PlotWindow` widget.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter == self.SNAPSHOT_FILTER_SVG:
+ self.plot.saveGraph(filename, fileFormat='svg')
+
+ else:
+ if hasattr(qt.QPixmap, "grabWidget"):
+ # Qt 4
+ pixmap = qt.QPixmap.grabWidget(self.plot.getWidgetHandle())
+ else:
+ # Qt 5
+ pixmap = self.plot.getWidgetHandle().grab()
+ if not pixmap.save(filename):
+ self._errorMessage()
+ return False
+ return True
+
+ def _saveCurve(self, filename, nameFilter):
+ """Save a curve from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.CURVE_FILTERS:
+ return False
+
+ # Check if a curve is to be saved
+ curve = self.plot.getActiveCurve()
+ # before calling _saveCurve, if there is no selected curve, we
+ # make sure there is only one curve on the graph
+ if curve is None:
+ curves = self.plot.getAllCurves()
+ if not curves:
+ self._errorMessage("No curve to be saved")
+ return False
+ curve = curves[0]
+
+ if nameFilter in self.CURVE_FILTERS_TXT:
+ filter_ = self.CURVE_FILTERS_TXT[nameFilter]
+ fmt = filter_['fmt']
+ csvdelim = filter_['delimiter']
+ autoheader = filter_['header']
+ else:
+ # .npy
+ fmt, csvdelim, autoheader = ("", "", False)
+
+ # If curve has no associated label, get the default from the plot
+ xlabel = curve.getXLabel()
+ if xlabel is None:
+ xlabel = self.plot.getGraphXLabel()
+ ylabel = curve.getYLabel()
+ if ylabel is None:
+ ylabel = self.plot.getGraphYLabel()
+
+ try:
+ save1D(filename,
+ curve.getXData(copy=False),
+ curve.getYData(copy=False),
+ xlabel, [ylabel],
+ fmt=fmt, csvdelim=csvdelim,
+ autoheader=autoheader)
+ except IOError:
+ self._errorMessage('Save failed\n')
+ return False
+
+ return True
+
+ def _saveCurves(self, filename, nameFilter):
+ """Save all curves from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.ALL_CURVES_FILTERS:
+ return False
+
+ curves = self.plot.getAllCurves()
+ if not curves:
+ self._errorMessage("No curves to be saved")
+ return False
+
+ curve = curves[0]
+ scanno = 1
+ try:
+ specfile = savespec(filename,
+ curve.getXData(copy=False),
+ curve.getYData(copy=False),
+ curve.getXLabel(),
+ curve.getYLabel(),
+ fmt="%.7g", scan_number=1, mode="w",
+ write_file_header=True,
+ close_file=False)
+ except IOError:
+ self._errorMessage('Save failed\n')
+ return False
+
+ for curve in curves[1:]:
+ try:
+ scanno += 1
+ specfile = savespec(specfile,
+ curve.getXData(copy=False),
+ curve.getYData(copy=False),
+ curve.getXLabel(),
+ curve.getYLabel(),
+ fmt="%.7g", scan_number=scanno, mode="w",
+ write_file_header=False,
+ close_file=False)
+ except IOError:
+ self._errorMessage('Save failed\n')
+ return False
+ specfile.close()
+
+ return True
+
+ def _saveImage(self, filename, nameFilter):
+ """Save an image from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.IMAGE_FILTERS:
+ return False
+
+ image = self.plot.getActiveImage()
+ if image is None:
+ qt.QMessageBox.warning(
+ self.plot, "No Data", "No image to be saved")
+ return False
+
+ data = image.getData(copy=False)
+
+ # TODO Use silx.io for writing files
+ if nameFilter == self.IMAGE_FILTER_EDF:
+ edfFile = EdfFile(filename, access="w+")
+ edfFile.WriteImage({}, data, Append=0)
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_TIFF:
+ tiffFile = TiffIO(filename, mode='w')
+ tiffFile.writeImage(data, software='silx')
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_NUMPY:
+ try:
+ numpy.save(filename, data)
+ except IOError:
+ self._errorMessage('Save failed\n')
+ return False
+ return True
+
+ elif nameFilter in (self.IMAGE_FILTER_ASCII,
+ self.IMAGE_FILTER_CSV_COMMA,
+ self.IMAGE_FILTER_CSV_SEMICOLON,
+ self.IMAGE_FILTER_CSV_TAB):
+ csvdelim, filetype = {
+ self.IMAGE_FILTER_ASCII: (' ', 'txt'),
+ self.IMAGE_FILTER_CSV_COMMA: (',', 'csv'),
+ self.IMAGE_FILTER_CSV_SEMICOLON: (';', 'csv'),
+ self.IMAGE_FILTER_CSV_TAB: ('\t', 'csv'),
+ }[nameFilter]
+
+ height, width = data.shape
+ rows, cols = numpy.mgrid[0:height, 0:width]
+ try:
+ save1D(filename, rows.ravel(), (cols.ravel(), data.ravel()),
+ filetype=filetype,
+ xlabel='row',
+ ylabels=['column', 'value'],
+ csvdelim=csvdelim,
+ autoheader=True)
+
+ except IOError:
+ self._errorMessage('Save failed\n')
+ return False
+ return True
+
+ elif nameFilter in (self.IMAGE_FILTER_RGB_PNG,
+ self.IMAGE_FILTER_RGB_TIFF):
+ # Get displayed image
+ rgbaImage = image.getRbgaImageData(copy=False)
+ # Convert RGB QImage
+ qimage = convertArrayToQImage(rgbaImage[:, :, :3])
+
+ if nameFilter == self.IMAGE_FILTER_RGB_PNG:
+ fileFormat = 'PNG'
+ else:
+ fileFormat = 'TIFF'
+
+ if qimage.save(filename, fileFormat):
+ return True
+ else:
+ _logger.error('Failed to save image as %s', filename)
+ qt.QMessageBox.critical(
+ self.parent(),
+ 'Save image as',
+ 'Failed to save image')
+
+ return False
+
+ def _actionTriggered(self, checked=False):
+ """Handle save action."""
+ # Set-up filters
+ filters = []
+
+ # Add image filters if there is an active image
+ if self.plot.getActiveImage() is not None:
+ filters.extend(self.IMAGE_FILTERS)
+
+ # Add curve filters if there is a curve to save
+ if (self.plot.getActiveCurve() is not None or
+ len(self.plot.getAllCurves()) == 1):
+ filters.extend(self.CURVE_FILTERS)
+ if len(self.plot.getAllCurves()) > 1:
+ filters.extend(self.ALL_CURVES_FILTERS)
+
+ filters.extend(self.SNAPSHOT_FILTERS)
+
+ # Create and run File dialog
+ dialog = qt.QFileDialog(self.plot)
+ dialog.setWindowTitle("Output File Selection")
+ dialog.setModal(1)
+ dialog.setNameFilters(filters)
+
+ dialog.setFileMode(dialog.AnyFile)
+ dialog.setAcceptMode(dialog.AcceptSave)
+
+ if not dialog.exec_():
+ return False
+
+ nameFilter = dialog.selectedNameFilter()
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ # Forces the filename extension to match the chosen filter
+ extension = nameFilter.split()[-1][2:-1]
+ if (len(filename) <= len(extension) or
+ filename[-len(extension):].lower() != extension.lower()):
+ filename += extension
+
+ # Handle save
+ if nameFilter in self.SNAPSHOT_FILTERS:
+ return self._saveSnapshot(filename, nameFilter)
+ elif nameFilter in self.CURVE_FILTERS:
+ return self._saveCurve(filename, nameFilter)
+ elif nameFilter in self.ALL_CURVES_FILTERS:
+ return self._saveCurves(filename, nameFilter)
+ elif nameFilter in self.IMAGE_FILTERS:
+ return self._saveImage(filename, nameFilter)
+ else:
+ _logger.warning('Unsupported file filter: %s', nameFilter)
+ return False
+
+
+def _plotAsPNG(plot):
+ """Save a :class:`Plot` as PNG and return the payload.
+
+ :param plot: The :class:`Plot` to save
+ """
+ pngFile = BytesIO()
+ plot.saveGraph(pngFile, fileFormat='png')
+ pngFile.flush()
+ pngFile.seek(0)
+ data = pngFile.read()
+ pngFile.close()
+ return data
+
+
+class PrintAction(PlotAction):
+ """QAction for printing the plot.
+
+ It opens a Print dialog.
+
+ Current implementation print a bitmap of the plot area and not vector
+ graphics, so printing quality is not great.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ :param parent: See :class:`QAction`.
+ """
+
+ # Share QPrinter instance to propose latest used as default
+ _printer = None
+
+ def __init__(self, plot, parent=None):
+ super(PrintAction, self).__init__(
+ plot, icon='document-print', text='Print...',
+ tooltip='Open print dialog',
+ triggered=self.printPlot,
+ checkable=False, parent=parent)
+ self.setShortcut(qt.QKeySequence.Print)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ @property
+ def printer(self):
+ """The QPrinter instance used by the actions.
+
+ This is shared accross all instances of PrintAct
+ """
+ if self._printer is None:
+ PrintAction._printer = qt.QPrinter()
+ return self._printer
+
+ def printPlotAsWidget(self):
+ """Open the print dialog and print the plot.
+
+ Use :meth:`QWidget.render` to print the plot
+
+ :return: True if successful
+ """
+ dialog = qt.QPrintDialog(self.printer, self.plot)
+ dialog.setWindowTitle('Print Plot')
+ if not dialog.exec_():
+ return False
+
+ # Print a snapshot of the plot widget at the top of the page
+ widget = self.plot.centralWidget()
+
+ painter = qt.QPainter()
+ if not painter.begin(self.printer):
+ return False
+
+ pageRect = self.printer.pageRect()
+ xScale = pageRect.width() / widget.width()
+ yScale = pageRect.height() / widget.height()
+ scale = min(xScale, yScale)
+
+ painter.translate(pageRect.width() / 2., 0.)
+ painter.scale(scale, scale)
+ painter.translate(-widget.width() / 2., 0.)
+ widget.render(painter)
+ painter.end()
+
+ return True
+
+ def printPlot(self):
+ """Open the print dialog and print the plot.
+
+ Use :meth:`Plot.saveGraph` to print the plot.
+
+ :return: True if successful
+ """
+ # Init printer and start printer dialog
+ dialog = qt.QPrintDialog(self.printer, self.plot)
+ dialog.setWindowTitle('Print Plot')
+ if not dialog.exec_():
+ return False
+
+ # Save Plot as PNG and make a pixmap from it with default dpi
+ pngData = _plotAsPNG(self.plot)
+
+ pixmap = qt.QPixmap()
+ pixmap.loadFromData(pngData, 'png')
+
+ xScale = self.printer.pageRect().width() / pixmap.width()
+ yScale = self.printer.pageRect().height() / pixmap.height()
+ scale = min(xScale, yScale)
+
+ # Draw pixmap with painter
+ painter = qt.QPainter()
+ if not painter.begin(self.printer):
+ return False
+
+ painter.drawPixmap(0, 0,
+ pixmap.width() * scale,
+ pixmap.height() * scale,
+ pixmap)
+ painter.end()
+
+ return True
+
+
+class CopyAction(PlotAction):
+ """QAction to copy :class:`.PlotWidget` content to clipboard.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(CopyAction, self).__init__(
+ plot, icon='edit-copy', text='Copy plot',
+ tooltip='Copy a snapshot of the plot into the clipboard',
+ triggered=self.copyPlot,
+ checkable=False, parent=parent)
+ self.setShortcut(qt.QKeySequence.Copy)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def copyPlot(self):
+ """Copy plot content to the clipboard as a bitmap."""
+ # Save Plot as PNG and make a QImage from it with default dpi
+ pngData = _plotAsPNG(self.plot)
+ image = qt.QImage.fromData(pngData, 'png')
+ qt.QApplication.clipboard().setImage(image)
+
+
+class CrosshairAction(PlotAction):
+ """QAction toggling crosshair cursor on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param str color: Color to use to draw the crosshair
+ :param int linewidth: Width of the crosshair cursor
+ :param str linestyle: Style of line. See :meth:`.Plot.setGraphCursor`
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, color='black', linewidth=1, linestyle='-',
+ parent=None):
+ self.color = color
+ """Color used to draw the crosshair (str)."""
+
+ self.linewidth = linewidth
+ """Width of the crosshair cursor (int)."""
+
+ self.linestyle = linestyle
+ """Style of line of the cursor (str)."""
+
+ super(CrosshairAction, self).__init__(
+ plot, icon='crosshair', text='Crosshair Cursor',
+ tooltip='Enable crosshair cursor when checked',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.setChecked(plot.getGraphCursor() is not None)
+ plot.sigSetGraphCursor.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setGraphCursor(checked,
+ color=self.color,
+ linestyle=self.linestyle,
+ linewidth=self.linewidth)
+
+
+class PanWithArrowKeysAction(PlotAction):
+ """QAction toggling pan with arrow keys on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+
+ super(PanWithArrowKeysAction, self).__init__(
+ plot, icon='arrow-keys', text='Pan with arrow keys',
+ tooltip='Enable pan with arrow keys when checked',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.setChecked(plot.isPanWithArrowKeys())
+ plot.sigSetPanWithArrowKeys.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setPanWithArrowKeys(checked)
+
+
+def _warningMessage(informativeText='', detailedText='', parent=None):
+ """Display a popup warning message."""
+ msg = qt.QMessageBox(parent)
+ msg.setIcon(qt.QMessageBox.Warning)
+ msg.setInformativeText(informativeText)
+ msg.setDetailedText(detailedText)
+ msg.exec_()
+
+
+def _getOneCurve(plt, mode="unique"):
+ """Get a single curve from the plot.
+ By default, get the active curve if any, else if a single curve is plotted
+ get it, else return None and display a warning popup.
+
+ This behavior can be adjusted by modifying the *mode* parameter: always
+ return the active curve if any, but adjust the behavior in case no curve
+ is active.
+
+ :param plt: :class:`.PlotWidget` instance on which to operate
+ :param mode: Parameter defining the behavior when no curve is active.
+ Possible modes:
+ - "none": return None (enforce curve activation)
+ - "unique": return the unique curve or None if multiple curves
+ - "first": return first curve
+ - "last": return last curve (most recently added one)
+ :return: return value of plt.getActiveCurve(), or plt.getAllCurves()[0],
+ or plt.getAllCurves()[-1], or None
+ """
+ curve = plt.getActiveCurve()
+ if curve is not None:
+ return curve
+
+ if mode is None or mode.lower() == "none":
+ _warningMessage("You must activate a curve!",
+ parent=plt)
+ return None
+
+ curves = plt.getAllCurves()
+ if len(curves) == 0:
+ _warningMessage("No curve on this plot.",
+ parent=plt)
+ return None
+
+ if len(curves) == 1:
+ return curves[0]
+
+ if len(curves) > 1:
+ if mode == "unique":
+ _warningMessage("Multiple curves are plotted. " +
+ "Please activate the one you want to use.",
+ parent=plt)
+ return None
+ if mode.lower() == "first":
+ return curves[0]
+ if mode.lower() == "last":
+ return curves[-1]
+
+ raise ValueError("Illegal value for parameter 'mode'." +
+ " Allowed values: 'none', 'unique', 'first', 'last'.")
+
+
+class FitAction(PlotAction):
+ """QAction to open a :class:`FitWidget` and set its data to the
+ active curve if any, or to the first curve.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+ def __init__(self, plot, parent=None):
+ super(FitAction, self).__init__(
+ plot, icon='math-fit', text='Fit curve',
+ tooltip='Open a fit dialog',
+ triggered=self._getFitWindow,
+ checkable=False, parent=parent)
+ self.fit_window = None
+
+ def _getFitWindow(self):
+ curve = _getOneCurve(self.plot)
+ if curve is None:
+ return
+ self.xlabel = self.plot.getGraphXLabel()
+ self.ylabel = self.plot.getGraphYLabel()
+ self.x = curve.getXData(copy=False)
+ self.y = curve.getYData(copy=False)
+ self.legend = curve.getLegend()
+ self.xmin, self.xmax = self.plot.getGraphXLimits()
+
+ # open a window with a FitWidget
+ if self.fit_window is None:
+ self.fit_window = qt.QMainWindow()
+ # import done here rather than at module level to avoid circular import
+ # FitWidget -> BackgroundWidget -> PlotWindow -> PlotActions -> FitWidget
+ from ..fit.FitWidget import FitWidget
+ self.fit_widget = FitWidget(parent=self.fit_window)
+ self.fit_window.setCentralWidget(
+ self.fit_widget)
+ self.fit_widget.guibuttons.DismissButton.clicked.connect(
+ self.fit_window.close)
+ self.fit_widget.sigFitWidgetSignal.connect(
+ self.handle_signal)
+ self.fit_window.show()
+ else:
+ if self.fit_window.isHidden():
+ self.fit_window.show()
+ self.fit_widget.show()
+ self.fit_window.raise_()
+
+ self.fit_widget.setData(self.x, self.y,
+ xmin=self.xmin, xmax=self.xmax)
+ self.fit_window.setWindowTitle(
+ "Fitting " + self.legend +
+ " on x range %f-%f" % (self.xmin, self.xmax))
+
+ def handle_signal(self, ddict):
+ x_fit = self.x[self.xmin <= self.x]
+ x_fit = x_fit[x_fit <= self.xmax]
+ fit_legend = "Fit <%s>" % self.legend
+ fit_curve = self.plot.getCurve(fit_legend)
+
+ if ddict["event"] == "FitFinished":
+ y_fit = self.fit_widget.fitmanager.gendata()
+ if fit_curve is None:
+ self.plot.addCurve(x_fit, y_fit,
+ fit_legend,
+ xlabel=self.xlabel, ylabel=self.ylabel,
+ resetzoom=False)
+ else:
+ fit_curve.setData(x_fit, y_fit)
+ fit_curve.setVisible(True)
+
+ if ddict["event"] in ["FitStarted", "FitFailed"]:
+ if fit_curve is not None:
+ fit_curve.setVisible(False)
+
+
+class PixelIntensitiesHistoAction(PlotAction):
+ """QAction to plot the pixels intensities diagram
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ PlotAction.__init__(self,
+ plot,
+ icon='pixel-intensities',
+ text='pixels intensity',
+ tooltip='Compute image intensity distribution',
+ triggered=self._triggered,
+ parent=parent,
+ checkable=True)
+ self._plotHistogram = None
+ self._connectedToActiveImage = False
+ self._histo = None
+
+ def _triggered(self, checked):
+ """Update the plot of the histogram visibility status
+
+ :param bool checked: status of the action button
+ """
+ if checked:
+ if not self._connectedToActiveImage:
+ self.plot.sigActiveImageChanged.connect(
+ self._activeImageChanged)
+ self._connectedToActiveImage = True
+ self.computeIntensityDistribution()
+
+ self.getHistogramPlotWidget().show()
+
+ else:
+ if self._connectedToActiveImage:
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChanged)
+ self._connectedToActiveImage = False
+
+ self.getHistogramPlotWidget().hide()
+
+ def _activeImageChanged(self, previous, legend):
+ """Handle active image change: toggle enabled toolbar, update curve"""
+ if self.isChecked():
+ self.computeIntensityDistribution()
+
+ def computeIntensityDistribution(self):
+ """Get the active image and compute the image intensity distribution
+ """
+ activeImage = self.plot.getActiveImage()
+
+ if activeImage is not None:
+ image = activeImage.getData(copy=False)
+ if image.ndim == 3: # RGB(A) images
+ _logger.info('Converting current image from RGB(A) to grayscale\
+ in order to compute the intensity distribution')
+ image = (image[:, :, 0] * 0.299 +
+ image[:, :, 1] * 0.587 +
+ image[:, :, 2] * 0.114)
+
+ xmin = numpy.nanmin(image)
+ xmax = numpy.nanmax(image)
+ nbins = min(1024, int(numpy.sqrt(image.size)))
+ data_range = xmin, xmax
+
+ # bad hack: get 256 bins in the case we have a B&W
+ if numpy.issubdtype(image.dtype, numpy.integer):
+ if nbins > xmax - xmin:
+ nbins = xmax - xmin
+
+ nbins = max(2, nbins)
+
+ data = image.ravel().astype(numpy.float32)
+ histogram = Histogramnd(data, n_bins=nbins, histo_range=data_range)
+ assert len(histogram.edges) == 1
+ self._histo = histogram.histo
+ edges = histogram.edges[0]
+ plot = self.getHistogramPlotWidget()
+ plot.addHistogram(histogram=self._histo,
+ edges=edges,
+ legend='pixel intensity',
+ fill=True,
+ color='red')
+ plot.resetZoom()
+
+ def eventFilter(self, qobject, event):
+ """Observe when the close event is emitted then
+ simply uncheck the action button
+
+ :param qobject: the object observe
+ :param event: the event received by qobject
+ """
+ if event.type() == qt.QEvent.Close:
+ if self._plotHistogram is not None:
+ self._plotHistogram.hide()
+ self.setChecked(False)
+
+ return PlotAction.eventFilter(self, qobject, event)
+
+ def getHistogramPlotWidget(self):
+ """Create the plot histogram if needed, otherwise create it
+
+ :return: the PlotWidget showing the histogram of the pixel intensities
+ """
+ from silx.gui.plot.PlotWindow import Plot1D
+ if self._plotHistogram is None:
+ self._plotHistogram = Plot1D(parent=self.plot)
+ self._plotHistogram.setWindowFlags(qt.Qt.Window)
+ self._plotHistogram.setWindowTitle('Image Intensity Histogram')
+ self._plotHistogram.installEventFilter(self)
+ self._plotHistogram.setGraphXLabel("Value")
+ self._plotHistogram.setGraphYLabel("Count")
+
+ return self._plotHistogram
+
+ def getHistogram(self):
+ """Return the last computed histogram
+
+ :return: the histogram displayed in the HistogramPlotWiget
+ """
+ return self._histo
+
+
+class MedianFilterAction(PlotAction):
+ """QAction to plot the pixels intensities diagram
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ PlotAction.__init__(self,
+ plot,
+ icon='median-filter',
+ text='median filter',
+ tooltip='Apply a median filter on the image',
+ triggered=self._triggered,
+ parent=parent)
+ self._originalImage = None
+ self._legend = None
+ self._filteredImage = None
+ self._popup = MedianFilterDialog(parent=None)
+ self._popup.sigFilterOptChanged.connect(self._updateFilter)
+ self.plot.sigActiveImageChanged.connect( self._updateActiveImage)
+ self._updateActiveImage()
+
+ def _triggered(self, checked):
+ """Update the plot of the histogram visibility status
+
+ :param bool checked: status of the action button
+ """
+ self._popup.show()
+
+ def _updateActiveImage(self):
+ """Set _activeImageLegend and _originalImage from the active image"""
+ self._activeImageLegend = self.plot.getActiveImage(just_legend=True)
+ if self._activeImageLegend is None:
+ self._originalImage = None
+ self._legend = None
+ else:
+ self._originalImage = self.plot.getImage(self._activeImageLegend).getData(copy=False)
+ self._legend = self.plot.getImage(self._activeImageLegend).getLegend()
+
+ def _updateFilter(self, kernelWidth, conditional=False):
+ if self._originalImage is None:
+ return
+
+ self.plot.sigActiveImageChanged.disconnect(self._updateActiveImage)
+ filteredImage = self._computeFilteredImage(kernelWidth, conditional)
+ self.plot.addImage(data=filteredImage,
+ legend=self._legend,
+ replace=True)
+ self.plot.sigActiveImageChanged.connect(self._updateActiveImage)
+
+ def _computeFilteredImage(self, kernelWidth, conditional):
+ raise NotImplemented('MedianFilterAction is a an abstract class')
+
+ def getFilteredImage(self):
+ """
+ :return: the image with the median filter apply on"""
+ return self._filteredImage
+
+
+class MedianFilter1DAction(MedianFilterAction):
+ """Define the MedianFilterAction for 1D
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+ def __init__(self, plot, parent=None):
+ MedianFilterAction.__init__(self,
+ plot,
+ parent=parent)
+
+ def _computeFilteredImage(self, kernelWidth, conditional):
+ assert(self.plot is not None)
+ return medfilt2d(self._originalImage,
+ (kernelWidth, 1),
+ conditional)
+
+
+class MedianFilter2DAction(MedianFilterAction):
+ """Define the MedianFilterAction for 2D
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+ def __init__(self, plot, parent=None):
+ MedianFilterAction.__init__(self,
+ plot,
+ parent=parent)
+
+ def _computeFilteredImage(self, kernelWidth, conditional):
+ assert(self.plot is not None)
+ return medfilt2d(self._originalImage,
+ (kernelWidth, kernelWidth),
+ conditional)
diff --git a/silx/gui/plot/PlotEvents.py b/silx/gui/plot/PlotEvents.py
new file mode 100644
index 0000000..83f253c
--- /dev/null
+++ b/silx/gui/plot/PlotEvents.py
@@ -0,0 +1,166 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Functions to prepare events to be sent to Plot callback."""
+
+__author__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/02/2016"
+
+
+import numpy as np
+
+
+def prepareDrawingSignal(event, type_, points, parameters=None):
+ """See Plot documentation for content of events"""
+ assert event in ('drawingProgress', 'drawingFinished')
+
+ if parameters is None:
+ parameters = {}
+
+ eventDict = {}
+ eventDict['event'] = event
+ eventDict['type'] = type_
+ points = np.array(points, dtype=np.float32)
+ points.shape = -1, 2
+ eventDict['points'] = points
+ eventDict['xdata'] = points[:, 0]
+ eventDict['ydata'] = points[:, 1]
+ if type_ in ('rectangle',):
+ eventDict['x'] = eventDict['xdata'].min()
+ eventDict['y'] = eventDict['ydata'].min()
+ eventDict['width'] = eventDict['xdata'].max() - eventDict['x']
+ eventDict['height'] = eventDict['ydata'].max() - eventDict['y']
+ eventDict['parameters'] = parameters.copy()
+ return eventDict
+
+
+def prepareMouseSignal(eventType, button, xData, yData, xPixel, yPixel):
+ """See Plot documentation for content of events"""
+ assert eventType in ('mouseMoved', 'mouseClicked', 'mouseDoubleClicked')
+ assert button in (None, 'left', 'middle', 'right')
+
+ return {'event': eventType,
+ 'x': xData,
+ 'y': yData,
+ 'xpixel': xPixel,
+ 'ypixel': yPixel,
+ 'button': button}
+
+
+def prepareHoverSignal(label, type_, posData, posPixel, draggable, selectable):
+ """See Plot documentation for content of events"""
+ return {'event': 'hover',
+ 'label': label,
+ 'type': type_,
+ 'x': posData[0],
+ 'y': posData[1],
+ 'xpixel': posPixel[0],
+ 'ypixel': posPixel[1],
+ 'draggable': draggable,
+ 'selectable': selectable}
+
+
+def prepareMarkerSignal(eventType, button, label, type_,
+ draggable, selectable,
+ posDataMarker,
+ posPixelCursor=None, posDataCursor=None):
+ """See Plot documentation for content of events"""
+ if eventType == 'markerClicked':
+ assert posPixelCursor is not None
+ assert posDataCursor is None
+
+ posDataCursor = list(posDataMarker)
+ if hasattr(posDataCursor[0], "__len__"):
+ posDataCursor[0] = posDataCursor[0][-1]
+ if hasattr(posDataCursor[1], "__len__"):
+ posDataCursor[1] = posDataCursor[1][-1]
+
+ elif eventType == 'markerMoving':
+ assert posPixelCursor is not None
+ assert posDataCursor is not None
+
+ elif eventType == 'markerMoved':
+ assert posPixelCursor is None
+ assert posDataCursor is None
+
+ posDataCursor = posDataMarker
+ else:
+ raise NotImplementedError("Unknown event type {0}".format(eventType))
+
+ eventDict = {'event': eventType,
+ 'button': button,
+ 'label': label,
+ 'type': type_,
+ 'x': posDataCursor[0],
+ 'y': posDataCursor[1],
+ 'xdata': posDataMarker[0],
+ 'ydata': posDataMarker[1],
+ 'draggable': draggable,
+ 'selectable': selectable}
+
+ if eventType in ('markerMoving', 'markerClicked'):
+ eventDict['xpixel'] = posPixelCursor[0]
+ eventDict['ypixel'] = posPixelCursor[1]
+
+ return eventDict
+
+
+def prepareImageSignal(button, label, type_, col, row,
+ x, y, xPixel, yPixel):
+ """See Plot documentation for content of events"""
+ return {'event': 'imageClicked',
+ 'button': button,
+ 'label': label,
+ 'type': type_,
+ 'col': col,
+ 'row': row,
+ 'x': x,
+ 'y': y,
+ 'xpixel': xPixel,
+ 'ypixel': yPixel}
+
+
+def prepareCurveSignal(button, label, type_, xData, yData,
+ x, y, xPixel, yPixel):
+ """See Plot documentation for content of events"""
+ return {'event': 'curveClicked',
+ 'button': button,
+ 'label': label,
+ 'type': type_,
+ 'xdata': xData,
+ 'ydata': yData,
+ 'x': x,
+ 'y': y,
+ 'xpixel': xPixel,
+ 'ypixel': yPixel}
+
+
+def prepareLimitsChangedSignal(sourceObj, xRange, yRange, y2Range):
+ """See Plot documentation for content of events"""
+ return {'event': 'limitsChanged',
+ 'source': id(sourceObj),
+ 'xdata': xRange,
+ 'ydata': yRange,
+ 'y2data': y2Range}
diff --git a/silx/gui/plot/PlotInteraction.py b/silx/gui/plot/PlotInteraction.py
new file mode 100644
index 0000000..fbc9c1f
--- /dev/null
+++ b/silx/gui/plot/PlotInteraction.py
@@ -0,0 +1,1493 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Implementation of the interaction for the :class:`Plot`."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/01/2017"
+
+
+import math
+import numpy
+import time
+import weakref
+
+from . import Colors
+from . import items
+from .Interaction import (ClickOrDrag, LEFT_BTN, RIGHT_BTN,
+ State, StateMachine)
+from .PlotEvents import (prepareCurveSignal, prepareDrawingSignal,
+ prepareHoverSignal, prepareImageSignal,
+ prepareMarkerSignal, prepareMouseSignal)
+
+from .backends.BackendBase import (CURSOR_POINTING, CURSOR_SIZE_HOR,
+ CURSOR_SIZE_VER, CURSOR_SIZE_ALL)
+
+from ._utils import (FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX,
+ applyZoomToPlot)
+
+
+# Base class ##################################################################
+
+class _PlotInteraction(object):
+ """Base class for interaction handler.
+
+ It provides a weakref to the plot and methods to set/reset overlay.
+ """
+ def __init__(self, plot):
+ """Init.
+
+ :param plot: The plot to apply modifications to.
+ """
+ self._needReplot = False
+ self._selectionAreas = set()
+ self._plot = weakref.ref(plot) # Avoid cyclic-ref
+
+ @property
+ def plot(self):
+ plot = self._plot()
+ assert plot is not None
+ return plot
+
+ def setSelectionArea(self, points, fill, color, name='', shape='polygon'):
+ """Set a polygon selection area overlaid on the plot.
+ Multiple simultaneous areas are supported through the name parameter.
+
+ :param points: The 2D coordinates of the points of the polygon
+ :type points: An iterable of (x, y) coordinates
+ :param str fill: The fill mode: 'hatch', 'solid' or 'none'
+ :param color: RGBA color to use or None to disable display
+ :type color: list or tuple of 4 float in the range [0, 1]
+ :param name: The key associated with this selection area
+ :param str shape: Shape of the area in 'polygon', 'polylines'
+ """
+ assert shape in ('polygon', 'polylines')
+
+ if color is None:
+ return
+
+ points = numpy.asarray(points)
+
+ # TODO Not very nice, but as is for now
+ legend = '__SELECTION_AREA__' + name
+
+ fill = fill != 'none' # TODO not very nice either
+
+ self.plot.addItem(points[:, 0], points[:, 1], legend=legend,
+ replace=False,
+ shape=shape, color=color, fill=fill,
+ overlay=True)
+ self._selectionAreas.add(legend)
+
+ def resetSelectionArea(self):
+ """Remove all selection areas set by setSelectionArea."""
+ for legend in self._selectionAreas:
+ self.plot.remove(legend, kind='item')
+ self._selectionAreas = set()
+
+
+# Zoom/Pan ####################################################################
+
+class _ZoomOnWheel(ClickOrDrag, _PlotInteraction):
+ """:class:`ClickOrDrag` state machine with zooming on mouse wheel.
+
+ Base class for :class:`Pan` and :class:`Zoom`
+ """
+ class ZoomIdle(ClickOrDrag.Idle):
+ def onWheel(self, x, y, angle):
+ scaleF = 1.1 if angle > 0 else 1. / 1.1
+ applyZoomToPlot(self.machine.plot, scaleF, (x, y))
+
+ def __init__(self, plot):
+ """Init.
+
+ :param plot: The plot to apply modifications to.
+ """
+ _PlotInteraction.__init__(self, plot)
+
+ states = {
+ 'idle': _ZoomOnWheel.ZoomIdle,
+ 'rightClick': ClickOrDrag.RightClick,
+ 'clickOrDrag': ClickOrDrag.ClickOrDrag,
+ 'drag': ClickOrDrag.Drag
+ }
+ StateMachine.__init__(self, states, 'idle')
+
+
+# Pan #########################################################################
+
+class Pan(_ZoomOnWheel):
+ """Pan plot content and zoom on wheel state machine."""
+
+ def _pixelToData(self, x, y):
+ xData, yData = self.plot.pixelToData(x, y)
+ _, y2Data = self.plot.pixelToData(x, y, axis='right')
+ return xData, yData, y2Data
+
+ def beginDrag(self, x, y):
+ self._previousDataPos = self._pixelToData(x, y)
+
+ def drag(self, x, y):
+ xData, yData, y2Data = self._pixelToData(x, y)
+ lastX, lastY, lastY2 = self._previousDataPos
+
+ xMin, xMax = self.plot.getGraphXLimits()
+ yMin, yMax = self.plot.getGraphYLimits(axis='left')
+ y2Min, y2Max = self.plot.getGraphYLimits(axis='right')
+
+ if self.plot.isXAxisLogarithmic():
+ try:
+ dx = math.log10(xData) - math.log10(lastX)
+ newXMin = pow(10., (math.log10(xMin) - dx))
+ newXMax = pow(10., (math.log10(xMax) - dx))
+ except (ValueError, OverflowError):
+ newXMin, newXMax = xMin, xMax
+
+ # Makes sure both values stays in positive float32 range
+ if newXMin < FLOAT32_MINPOS or newXMax > FLOAT32_SAFE_MAX:
+ newXMin, newXMax = xMin, xMax
+ else:
+ dx = xData - lastX
+ newXMin, newXMax = xMin - dx, xMax - dx
+
+ # Makes sure both values stays in float32 range
+ if newXMin < FLOAT32_SAFE_MIN or newXMax > FLOAT32_SAFE_MAX:
+ newXMin, newXMax = xMin, xMax
+
+ if self.plot.isYAxisLogarithmic():
+ try:
+ dy = math.log10(yData) - math.log10(lastY)
+ newYMin = pow(10., math.log10(yMin) - dy)
+ newYMax = pow(10., math.log10(yMax) - dy)
+
+ dy2 = math.log10(y2Data) - math.log10(lastY2)
+ newY2Min = pow(10., math.log10(y2Min) - dy2)
+ newY2Max = pow(10., math.log10(y2Max) - dy2)
+ except (ValueError, OverflowError):
+ newYMin, newYMax = yMin, yMax
+ newY2Min, newY2Max = y2Min, y2Max
+
+ # Makes sure y and y2 stays in positive float32 range
+ if (newYMin < FLOAT32_MINPOS or newYMax > FLOAT32_SAFE_MAX or
+ newY2Min < FLOAT32_MINPOS or newY2Max > FLOAT32_SAFE_MAX):
+ newYMin, newYMax = yMin, yMax
+ newY2Min, newY2Max = y2Min, y2Max
+ else:
+ dy = yData - lastY
+ dy2 = y2Data - lastY2
+ newYMin, newYMax = yMin - dy, yMax - dy
+ newY2Min, newY2Max = y2Min - dy2, y2Max - dy2
+
+ # Makes sure y and y2 stays in float32 range
+ if (newYMin < FLOAT32_SAFE_MIN or
+ newYMax > FLOAT32_SAFE_MAX or
+ newY2Min < FLOAT32_SAFE_MIN or
+ newY2Max > FLOAT32_SAFE_MAX):
+ newYMin, newYMax = yMin, yMax
+ newY2Min, newY2Max = y2Min, y2Max
+
+ self.plot.setLimits(newXMin, newXMax,
+ newYMin, newYMax,
+ newY2Min, newY2Max)
+
+ self._previousDataPos = self._pixelToData(x, y)
+
+ def endDrag(self, startPos, endPos):
+ del self._previousDataPos
+
+ def cancel(self):
+ pass
+
+
+# Zoom ########################################################################
+
+class Zoom(_ZoomOnWheel):
+ """Zoom-in/out state machine.
+
+ Zoom-in on selected area, zoom-out on right click,
+ and zoom on mouse wheel.
+ """
+ _DOUBLE_CLICK_TIMEOUT = 0.4
+
+ def __init__(self, plot, color):
+ self.color = color
+ self.zoomStack = []
+ self._lastClick = 0., None
+
+ super(Zoom, self).__init__(plot)
+
+ def _areaWithAspectRatio(self, x0, y0, x1, y1):
+ _plotLeft, _plotTop, plotW, plotH = self.plot.getPlotBoundsInPixels()
+
+ areaX0, areaY0, areaX1, areaY1 = x0, y0, x1, y1
+
+ if plotH != 0.:
+ plotRatio = plotW / float(plotH)
+ width, height = math.fabs(x1 - x0), math.fabs(y1 - y0)
+
+ if height != 0. and width != 0.:
+ if width / height > plotRatio:
+ areaHeight = width / plotRatio
+ areaX0, areaX1 = x0, x1
+ center = 0.5 * (y0 + y1)
+ areaY0 = center - numpy.sign(y1 - y0) * 0.5 * areaHeight
+ areaY1 = center + numpy.sign(y1 - y0) * 0.5 * areaHeight
+ else:
+ areaWidth = height * plotRatio
+ areaY0, areaY1 = y0, y1
+ center = 0.5 * (x0 + x1)
+ areaX0 = center - numpy.sign(x1 - x0) * 0.5 * areaWidth
+ areaX1 = center + numpy.sign(x1 - x0) * 0.5 * areaWidth
+
+ return areaX0, areaY0, areaX1, areaY1
+
+ def click(self, x, y, btn):
+ if btn == LEFT_BTN:
+ lastClickTime, lastClickPos = self._lastClick
+
+ # Signal mouse double clicked event first
+ if (time.time() - lastClickTime) <= self._DOUBLE_CLICK_TIMEOUT:
+ # Use position of first click
+ eventDict = prepareMouseSignal('mouseDoubleClicked', 'left',
+ *lastClickPos)
+ self.plot.notify(**eventDict)
+
+ self._lastClick = 0., None
+ else:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareMouseSignal('mouseClicked', 'left',
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**eventDict)
+
+ self._lastClick = time.time(), (dataPos[0], dataPos[1], x, y)
+
+ # Zoom-in centered on mouse cursor
+ # xMin, xMax = self.plot.getGraphXLimits()
+ # yMin, yMax = self.plot.getGraphYLimits()
+ # y2Min, y2Max = self.plot.getGraphYLimits(axis="right")
+ # self.zoomStack.append((xMin, xMax, yMin, yMax, y2Min, y2Max))
+ # self._zoom(x, y, 2)
+ elif btn == RIGHT_BTN:
+ try:
+ xMin, xMax, yMin, yMax, y2Min, y2Max = self.zoomStack.pop()
+ except IndexError:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareMouseSignal('mouseClicked', 'right',
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**eventDict)
+ else:
+ self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+
+ def beginDrag(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self.x0, self.y0 = x, y
+
+ def drag(self, x1, y1):
+ if self.color is None:
+ return # Do not draw zoom area
+
+ dataPos = self.plot.pixelToData(x1, y1)
+ assert dataPos is not None
+
+ if self.plot.isKeepDataAspectRatio():
+ area = self._areaWithAspectRatio(self.x0, self.y0, x1, y1)
+ areaX0, areaY0, areaX1, areaY1 = area
+ areaPoints = ((areaX0, areaY0),
+ (areaX1, areaY0),
+ (areaX1, areaY1),
+ (areaX0, areaY1))
+ areaPoints = numpy.array([self.plot.pixelToData(
+ x, y, check=False) for (x, y) in areaPoints])
+
+ if self.color != 'video inverted':
+ areaColor = list(self.color)
+ areaColor[3] *= 0.25
+ else:
+ areaColor = [1., 1., 1., 1.]
+
+ self.setSelectionArea(areaPoints,
+ fill='none',
+ color=areaColor,
+ name="zoomedArea")
+
+ corners = ((self.x0, self.y0),
+ (self.x0, y1),
+ (x1, y1),
+ (x1, self.y0))
+ corners = numpy.array([self.plot.pixelToData(x, y, check=False)
+ for (x, y) in corners])
+
+ self.setSelectionArea(corners, fill='none', color=self.color)
+
+ def endDrag(self, startPos, endPos):
+ x0, y0 = startPos
+ x1, y1 = endPos
+
+ if x0 != x1 or y0 != y1: # Avoid empty zoom area
+ # Store current zoom state in stack
+ xMin, xMax = self.plot.getGraphXLimits()
+ yMin, yMax = self.plot.getGraphYLimits()
+ y2Min, y2Max = self.plot.getGraphYLimits(axis="right")
+ self.zoomStack.append((xMin, xMax, yMin, yMax, y2Min, y2Max))
+
+ if self.plot.isKeepDataAspectRatio():
+ x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1)
+
+ # Convert to data space and set limits
+ x0, y0 = self.plot.pixelToData(x0, y0, check=False)
+
+ dataPos = self.plot.pixelToData(
+ startPos[0], startPos[1], axis="right", check=False)
+ y2_0 = dataPos[1]
+
+ x1, y1 = self.plot.pixelToData(x1, y1, check=False)
+
+ dataPos = self.plot.pixelToData(
+ endPos[0], endPos[1], axis="right", check=False)
+ y2_1 = dataPos[1]
+
+ xMin, xMax = min(x0, x1), max(x0, x1)
+ yMin, yMax = min(y0, y1), max(y0, y1)
+ y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1)
+
+ self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+
+ self.resetSelectionArea()
+
+ def cancel(self):
+ if isinstance(self.state, self.states['drag']):
+ self.resetSelectionArea()
+
+
+# Select ######################################################################
+
+class Select(StateMachine, _PlotInteraction):
+ """Base class for drawing selection areas."""
+
+ def __init__(self, plot, parameters, states, state):
+ """Init a state machine.
+
+ :param plot: The plot to apply changes to.
+ :param dict parameters: A dict of parameters such as color.
+ :param dict states: The states of the state machine.
+ :param str state: The name of the initial state.
+ """
+ _PlotInteraction.__init__(self, plot)
+ self.parameters = parameters
+ StateMachine.__init__(self, states, state)
+
+ def onWheel(self, x, y, angle):
+ scaleF = 1.1 if angle > 0 else 1. / 1.1
+ applyZoomToPlot(self.plot, scaleF, (x, y))
+
+ @property
+ def color(self):
+ return self.parameters.get('color', None)
+
+
+class SelectPolygon(Select):
+ """Drawing selection polygon area state machine."""
+
+ DRAG_THRESHOLD_DIST = 4
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('select', x, y)
+ return True
+
+ class Select(State):
+ def enterState(self, x, y):
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self._firstPos = dataPos
+ self.points = [dataPos, dataPos]
+
+ self.updateFirstPoint()
+
+ def updateFirstPoint(self):
+ """Update drawing first point, using self._firstPos"""
+ x, y = self.machine.plot.dataToPixel(*self._firstPos, check=False)
+
+ offset = self.machine.DRAG_THRESHOLD_DIST
+ points = [(x - offset, y - offset),
+ (x - offset, y + offset),
+ (x + offset, y + offset),
+ (x + offset, y - offset)]
+ points = [self.machine.plot.pixelToData(xpix, ypix, check=False)
+ for xpix, ypix in points]
+ self.machine.setSelectionArea(points, fill=None,
+ color=self.machine.color,
+ name='first_point')
+
+ def updateSelectionArea(self):
+ """Update drawing selection area using self.points"""
+ self.machine.setSelectionArea(self.points,
+ fill='hatch',
+ color=self.machine.color)
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'polygon',
+ self.points,
+ self.machine.parameters)
+ self.machine.plot.notify(**eventDict)
+
+ def onWheel(self, x, y, angle):
+ self.machine.onWheel(x, y, angle)
+ self.updateFirstPoint()
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ # checking if the position is close to the first point
+ # if yes : closing the "loop"
+ firstPos = self.machine.plot.dataToPixel(*self._firstPos,
+ check=False)
+ dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y)
+
+ # Only allow to close polygon after first point
+ if (len(self.points) > 2 and
+ dx < self.machine.DRAG_THRESHOLD_DIST and
+ dy < self.machine.DRAG_THRESHOLD_DIST):
+ self.machine.resetSelectionArea()
+
+ self.points[-1] = self.points[0]
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'polygon',
+ self.points,
+ self.machine.parameters)
+ self.machine.plot.notify(**eventDict)
+ self.goto('idle')
+ return False
+
+ # Update polygon last point not too close to previous one
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self.updateSelectionArea()
+
+ # checking that the new points isnt the same (within range)
+ # of the previous one
+ # This has to be done because sometimes the mouse release event
+ # is caught right after entering the Select state (i.e : press
+ # in Idle state, but with a slightly different position that
+ # the mouse press. So we had the two first vertices that were
+ # almost identical.
+ previousPos = self.machine.plot.dataToPixel(*self.points[-2],
+ check=False)
+ dx, dy = abs(previousPos[0] - x), abs(previousPos[1] - y)
+ if(dx >= self.machine.DRAG_THRESHOLD_DIST or
+ dy >= self.machine.DRAG_THRESHOLD_DIST):
+ self.points.append(dataPos)
+ else:
+ self.points[-1] = dataPos
+
+ return True
+
+ elif btn == RIGHT_BTN:
+ self.machine.resetSelectionArea()
+
+ firstPos = self.machine.plot.dataToPixel(*self._firstPos,
+ check=False)
+ dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y)
+
+ if (dx < self.machine.DRAG_THRESHOLD_DIST and
+ dy < self.machine.DRAG_THRESHOLD_DIST):
+ self.points[-1] = self.points[0]
+ else:
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self.points[-1] = dataPos
+ if self.points[-2] == self.points[-1]:
+ self.points.pop()
+ self.points.append(self.points[0])
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'polygon',
+ self.points,
+ self.machine.parameters)
+ self.machine.plot.notify(**eventDict)
+ self.goto('idle')
+ return False
+
+ return False
+
+ def onMove(self, x, y):
+ firstPos = self.machine.plot.dataToPixel(*self._firstPos,
+ check=False)
+ dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y)
+ if (dx < self.machine.DRAG_THRESHOLD_DIST and
+ dy < self.machine.DRAG_THRESHOLD_DIST):
+ x, y = firstPos # Snap to first point
+
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self.points[-1] = dataPos
+ self.updateSelectionArea()
+
+ def __init__(self, plot, parameters):
+ states = {
+ 'idle': SelectPolygon.Idle,
+ 'select': SelectPolygon.Select
+ }
+ super(SelectPolygon, self).__init__(plot, parameters,
+ states, 'idle')
+
+ def cancel(self):
+ if isinstance(self.state, self.states['select']):
+ self.resetSelectionArea()
+
+
+class Select2Points(Select):
+ """Base class for drawing selection based on 2 input points."""
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('start', x, y)
+ return True
+
+ class Start(State):
+ def enterState(self, x, y):
+ self.machine.beginSelect(x, y)
+
+ def onMove(self, x, y):
+ self.goto('select', x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('select', x, y)
+ return True
+
+ class Select(State):
+ def enterState(self, x, y):
+ self.onMove(x, y)
+
+ def onMove(self, x, y):
+ self.machine.select(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.machine.endSelect(x, y)
+ self.goto('idle')
+
+ def __init__(self, plot, parameters):
+ states = {
+ 'idle': Select2Points.Idle,
+ 'start': Select2Points.Start,
+ 'select': Select2Points.Select
+ }
+ super(Select2Points, self).__init__(plot, parameters,
+ states, 'idle')
+
+ def beginSelect(self, x, y):
+ pass
+
+ def select(self, x, y):
+ pass
+
+ def endSelect(self, x, y):
+ pass
+
+ def cancelSelect(self):
+ pass
+
+ def cancel(self):
+ if isinstance(self.state, self.states['select']):
+ self.cancelSelect()
+
+
+class SelectRectangle(Select2Points):
+ """Drawing rectangle selection area state machine."""
+ def beginSelect(self, x, y):
+ self.startPt = self.plot.pixelToData(x, y)
+ assert self.startPt is not None
+
+ def select(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ self.setSelectionArea((self.startPt,
+ (self.startPt[0], dataPos[1]),
+ dataPos,
+ (dataPos[0], self.startPt[1])),
+ fill='hatch',
+ color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'rectangle',
+ (self.startPt, dataPos),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'rectangle',
+ (self.startPt, dataPos),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class SelectLine(Select2Points):
+ """Drawing line selection area state machine."""
+ def beginSelect(self, x, y):
+ self.startPt = self.plot.pixelToData(x, y)
+ assert self.startPt is not None
+
+ def select(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ self.setSelectionArea((self.startPt, dataPos),
+ fill='hatch',
+ color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'line',
+ (self.startPt, dataPos),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'line',
+ (self.startPt, dataPos),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class Select1Point(Select):
+ """Base class for drawing selection area based on one input point."""
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('select', x, y)
+ return True
+
+ class Select(State):
+ def enterState(self, x, y):
+ self.onMove(x, y)
+
+ def onMove(self, x, y):
+ self.machine.select(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.machine.endSelect(x, y)
+ self.goto('idle')
+
+ def onWheel(self, x, y, angle):
+ self.machine.onWheel(x, y, angle) # Call select default wheel
+ self.machine.select(x, y)
+
+ def __init__(self, plot, parameters):
+ states = {
+ 'idle': Select1Point.Idle,
+ 'select': Select1Point.Select
+ }
+ super(Select1Point, self).__init__(plot, parameters, states, 'idle')
+
+ def select(self, x, y):
+ pass
+
+ def endSelect(self, x, y):
+ pass
+
+ def cancelSelect(self):
+ pass
+
+ def cancel(self):
+ if isinstance(self.state, self.states['select']):
+ self.cancelSelect()
+
+
+class SelectHLine(Select1Point):
+ """Drawing a horizontal line selection area state machine."""
+ def _hLine(self, y):
+ """Return points in data coords of the segment visible in the plot.
+
+ Supports non-orthogonal axes.
+ """
+ left, _top, width, _height = self.plot.getPlotBoundsInPixels()
+
+ dataPos1 = self.plot.pixelToData(left, y, check=False)
+ dataPos2 = self.plot.pixelToData(left + width, y, check=False)
+ return dataPos1, dataPos2
+
+ def select(self, x, y):
+ points = self._hLine(y)
+ self.setSelectionArea(points, fill='hatch', color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'hline',
+ points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'hline',
+ self._hLine(y),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class SelectVLine(Select1Point):
+ """Drawing a vertical line selection area state machine."""
+ def _vLine(self, x):
+ """Return points in data coords of the segment visible in the plot.
+
+ Supports non-orthogonal axes.
+ """
+ _left, top, _width, height = self.plot.getPlotBoundsInPixels()
+
+ dataPos1 = self.plot.pixelToData(x, top, check=False)
+ dataPos2 = self.plot.pixelToData(x, top + height, check=False)
+ return dataPos1, dataPos2
+
+ def select(self, x, y):
+ points = self._vLine(x)
+ self.setSelectionArea(points, fill='hatch', color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'vline',
+ points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'vline',
+ self._vLine(x),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class DrawFreeHand(Select):
+ """Interaction for drawing pencil. It display the preview of the pencil
+ before pressing the mouse.
+ """
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('select', x, y)
+ return True
+
+ def onMove(self, x, y):
+ self.machine.updatePencilShape(x, y)
+
+ def onLeave(self):
+ self.machine.cancel()
+
+ class Select(State):
+ def enterState(self, x, y):
+ self.__isOut = False
+ self.machine.setFirstPoint(x, y)
+
+ def onMove(self, x, y):
+ self.machine.updatePencilShape(x, y)
+ self.machine.select(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ if self.__isOut:
+ self.machine.resetSelectionArea()
+ self.machine.endSelect(x, y)
+ self.goto('idle')
+
+ def onEnter(self):
+ self.__isOut = False
+
+ def onLeave(self):
+ self.__isOut = True
+
+ def __init__(self, plot, parameters):
+ # Circle used for pencil preview
+ angle = numpy.arange(13.) * numpy.pi * 2.0 / 13.
+ size = parameters.get('width', 1.) * 0.5
+ self._circle = size * numpy.array((numpy.cos(angle),
+ numpy.sin(angle))).T
+
+ states = {
+ 'idle': DrawFreeHand.Idle,
+ 'select': DrawFreeHand.Select
+ }
+ super(DrawFreeHand, self).__init__(plot, parameters, states, 'idle')
+
+ @property
+ def width(self):
+ return self.parameters.get('width', None)
+
+ def setFirstPoint(self, x, y):
+ self._points = []
+ self.select(x, y)
+
+ def updatePencilShape(self, x, y):
+ center = self.plot.pixelToData(x, y, check=False)
+ assert center is not None
+
+ polygon = center + self._circle
+
+ self.setSelectionArea(polygon, fill='none', color=self.color)
+
+ def select(self, x, y):
+ pos = self.plot.pixelToData(x, y, check=False)
+ if len(self._points) > 0:
+ if self._points[-1] == pos:
+ # Skip same points
+ return
+ self._points.append(pos)
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'polylines',
+ self._points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ pos = self.plot.pixelToData(x, y, check=False)
+ if len(self._points) > 0:
+ if self._points[-1] != pos:
+ # Append if different
+ self._points.append(pos)
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'polylines',
+ self._points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+ self._points = None
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+ def cancel(self):
+ self.resetSelectionArea()
+
+
+class SelectFreeLine(ClickOrDrag, _PlotInteraction):
+ """Base class for drawing free lines with tools such as pencil."""
+
+ def __init__(self, plot, parameters):
+ """Init a state machine.
+
+ :param plot: The plot to apply changes to.
+ :param dict parameters: A dict of parameters such as color.
+ """
+ # self.DRAG_THRESHOLD_SQUARE_DIST = 1 # Disable first move threshold
+ self._points = []
+ ClickOrDrag.__init__(self)
+ _PlotInteraction.__init__(self, plot)
+ self.parameters = parameters
+
+ def onWheel(self, x, y, angle):
+ scaleF = 1.1 if angle > 0 else 1. / 1.1
+ applyZoomToPlot(self.plot, scaleF, (x, y))
+
+ @property
+ def color(self):
+ return self.parameters.get('color', None)
+
+ def click(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self._processEvent(x, y, isLast=True)
+
+ def beginDrag(self, x, y):
+ self._processEvent(x, y, isLast=False)
+
+ def drag(self, x, y):
+ self._processEvent(x, y, isLast=False)
+
+ def endDrag(self, startPos, endPos):
+ x, y = endPos
+ self._processEvent(x, y, isLast=True)
+
+ def cancel(self):
+ self.resetSelectionArea()
+ self._points = []
+
+ def _processEvent(self, x, y, isLast):
+ dataPos = self.plot.pixelToData(x, y, check=False)
+ isNewPoint = not self._points or dataPos != self._points[-1]
+
+ if isNewPoint:
+ self._points.append(dataPos)
+
+ if isNewPoint or isLast:
+ eventDict = prepareDrawingSignal(
+ 'drawingFinished' if isLast else 'drawingProgress',
+ 'polylines',
+ self._points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ if not isLast:
+ self.setSelectionArea(self._points, fill='none', color=self.color,
+ shape='polylines')
+ else:
+ self.cancel()
+
+
+# ItemInteraction #############################################################
+
+class ItemsInteraction(ClickOrDrag, _PlotInteraction):
+ """Interaction with items (markers, curves and images).
+
+ This class provides selection and dragging of plot primitives
+ that support those interaction.
+ It is also meant to be combined with the zoom interaction.
+ """
+
+ class Idle(ClickOrDrag.Idle):
+ def __init__(self, *args, **kw):
+ super(ItemsInteraction.Idle, self).__init__(*args, **kw)
+ self._hoverMarker = None
+
+ def onWheel(self, x, y, angle):
+ scaleF = 1.1 if angle > 0 else 1. / 1.1
+ applyZoomToPlot(self.machine.plot, scaleF, (x, y))
+
+ def onMove(self, x, y):
+ marker = self.machine.plot._pickMarker(x, y)
+ if marker is not None:
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareHoverSignal(
+ marker.getLegend(), 'marker',
+ dataPos, (x, y),
+ marker.isDraggable(),
+ marker.isSelectable())
+ self.machine.plot.notify(**eventDict)
+
+ if marker != self._hoverMarker:
+ self._hoverMarker = marker
+
+ if marker is None:
+ self.machine.plot.setGraphCursorShape()
+
+ elif marker.isDraggable():
+ if isinstance(marker, items.YMarker):
+ self.machine.plot.setGraphCursorShape(CURSOR_SIZE_VER)
+ elif isinstance(marker, items.XMarker):
+ self.machine.plot.setGraphCursorShape(CURSOR_SIZE_HOR)
+ else:
+ self.machine.plot.setGraphCursorShape(CURSOR_SIZE_ALL)
+
+ elif marker.isSelectable():
+ self.machine.plot.setGraphCursorShape(CURSOR_POINTING)
+
+ return True
+
+ def __init__(self, plot):
+ _PlotInteraction.__init__(self, plot)
+
+ states = {
+ 'idle': ItemsInteraction.Idle,
+ 'rightClick': ClickOrDrag.RightClick,
+ 'clickOrDrag': ClickOrDrag.ClickOrDrag,
+ 'drag': ClickOrDrag.Drag
+ }
+ StateMachine.__init__(self, states, 'idle')
+
+ def click(self, x, y, btn):
+ """Handle mouse click
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: True if click is catched by an item, False otherwise
+ """
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareMouseSignal('mouseClicked', btn,
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**eventDict)
+
+ eventDict = self._handleClick(x, y, btn)
+ if eventDict is not None:
+ self.plot.notify(**eventDict)
+
+ def _handleClick(self, x, y, btn):
+ """Perform picking and prepare event if click is handled here
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: event description to send of None if not handling event.
+ :rtype: dict or None
+ """
+
+ if btn == LEFT_BTN:
+ marker = self.plot._pickMarker(
+ x, y, lambda m: m.isSelectable())
+ if marker is not None:
+ xData, yData = marker.getPosition()
+ if xData is None:
+ xData = [0, 1]
+ if yData is None:
+ yData = [0, 1]
+
+ eventDict = prepareMarkerSignal('markerClicked',
+ 'left',
+ marker.getLegend(),
+ 'marker',
+ marker.isDraggable(),
+ marker.isSelectable(),
+ (xData, yData),
+ (x, y), None)
+ return eventDict
+
+ else:
+ picked = self.plot._pickImageOrCurve(
+ x, y, lambda item: item.isSelectable())
+
+ if picked is None:
+ pass
+
+ elif picked[0] == 'curve':
+ curve = picked[1]
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ eventDict = prepareCurveSignal('left',
+ curve.getLegend(),
+ 'curve',
+ picked[2], picked[3],
+ dataPos[0], dataPos[1],
+ x, y)
+ return eventDict
+
+ elif picked[0] == 'image':
+ image = picked[1]
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ # Get corresponding coordinate in image
+ origin = image.getOrigin()
+ scale = image.getScale()
+ column = int((dataPos[0] - origin[0]) / float(scale[0]))
+ row = int((dataPos[1] - origin[1]) / float(scale[1]))
+
+ eventDict = prepareImageSignal('left',
+ image.getLegend(),
+ 'image',
+ column, row,
+ dataPos[0], dataPos[1],
+ x, y)
+ return eventDict
+
+ return None
+
+ def _signalMarkerMovingEvent(self, eventType, marker, x, y):
+ assert marker is not None
+
+ xData, yData = marker.getPosition()
+ if xData is None:
+ xData = [0, 1]
+ if yData is None:
+ yData = [0, 1]
+
+ posDataCursor = self.plot.pixelToData(x, y)
+ assert posDataCursor is not None
+
+ eventDict = prepareMarkerSignal(eventType,
+ 'left',
+ marker.getLegend(),
+ 'marker',
+ marker.isDraggable(),
+ marker.isSelectable(),
+ (xData, yData),
+ (x, y),
+ posDataCursor)
+ self.plot.notify(**eventDict)
+
+ def beginDrag(self, x, y):
+ """Handle begining of drag interaction
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :return: True if drag is catched by an item, False otherwise
+ """
+ self._lastPos = self.plot.pixelToData(x, y)
+ assert self._lastPos is not None
+
+ self.imageLegend = None
+ self.markerLegend = None
+ marker = self.plot._pickMarker(
+ x, y, lambda m: m.isDraggable())
+
+ if marker is not None:
+ self.markerLegend = marker.getLegend()
+ self._signalMarkerMovingEvent('markerMoving', marker, x, y)
+ else:
+ picked = self.plot._pickImageOrCurve(
+ x,
+ y,
+ lambda item:
+ hasattr(item, 'isDraggable') and item.isDraggable())
+ if picked is None:
+ self.imageLegend = None
+ self.plot.setGraphCursorShape()
+ return False
+ else:
+ assert picked[0] == 'image' # For now only drag images
+ self.imageLegend = picked[1].getLegend()
+ return True
+
+ def drag(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ xData, yData = dataPos
+
+ if self.markerLegend is not None:
+ marker = self.plot._getMarker(self.markerLegend)
+ if marker is not None:
+ marker.setPosition(xData, yData)
+
+ self._signalMarkerMovingEvent(
+ 'markerMoving', marker, x, y)
+
+ if self.imageLegend is not None:
+ image = self.plot.getImage(self.imageLegend)
+ origin = image.getOrigin()
+ xImage = origin[0] + xData - self._lastPos[0]
+ yImage = origin[1] + yData - self._lastPos[1]
+ image.setOrigin((xImage, yImage))
+
+ self._lastPos = xData, yData
+
+ def endDrag(self, startPos, endPos):
+ if self.markerLegend is not None:
+ marker = self.plot._getMarker(self.markerLegend)
+ posData = list(marker.getPosition())
+ if posData[0] is None:
+ posData[0] = [0, 1]
+ if posData[1] is None:
+ posData[1] = [0, 1]
+
+ eventDict = prepareMarkerSignal(
+ 'markerMoved',
+ 'left',
+ marker.getLegend(),
+ 'marker',
+ marker.isDraggable(),
+ marker.isSelectable(),
+ posData)
+ self.plot.notify(**eventDict)
+
+ self.plot.setGraphCursorShape()
+
+ del self.markerLegend
+ del self.imageLegend
+ del self._lastPos
+
+ def cancel(self):
+ self.plot.setGraphCursorShape()
+
+
+# FocusManager ################################################################
+
+class FocusManager(StateMachine):
+ """Manages focus across multiple event handlers
+
+ On press an event handler can acquire focus.
+ By default it looses focus when all buttons are released.
+ """
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ for eventHandler in self.machine.eventHandlers:
+ requestFocus = eventHandler.handleEvent('press', x, y, btn)
+ if requestFocus:
+ self.goto('focus', eventHandler, btn)
+ break
+
+ def _processEvent(self, *args):
+ for eventHandler in self.machine.eventHandlers:
+ consumeEvent = eventHandler.handleEvent(*args)
+ if consumeEvent:
+ break
+
+ def onMove(self, x, y):
+ self._processEvent('move', x, y)
+
+ def onRelease(self, x, y, btn):
+ self._processEvent('release', x, y, btn)
+
+ def onWheel(self, x, y, angle):
+ self._processEvent('wheel', x, y, angle)
+
+ class Focus(State):
+ def enterState(self, eventHandler, btn):
+ self.eventHandler = eventHandler
+ self.focusBtns = {btn}
+
+ def onPress(self, x, y, btn):
+ self.focusBtns.add(btn)
+ self.eventHandler.handleEvent('press', x, y, btn)
+
+ def onMove(self, x, y):
+ self.eventHandler.handleEvent('move', x, y)
+
+ def onRelease(self, x, y, btn):
+ self.focusBtns.discard(btn)
+ requestFocus = self.eventHandler.handleEvent('release', x, y, btn)
+ if len(self.focusBtns) == 0 and not requestFocus:
+ self.goto('idle')
+
+ def onWheel(self, x, y, angleInDegrees):
+ self.eventHandler.handleEvent('wheel', x, y, angleInDegrees)
+
+ def __init__(self, eventHandlers=()):
+ self.eventHandlers = list(eventHandlers)
+
+ states = {
+ 'idle': FocusManager.Idle,
+ 'focus': FocusManager.Focus
+ }
+ super(FocusManager, self).__init__(states, 'idle')
+
+ def cancel(self):
+ for handler in self.eventHandlers:
+ handler.cancel()
+
+
+class ZoomAndSelect(ItemsInteraction):
+ """Combine Zoom and ItemInteraction state machine.
+
+ :param plot: The Plot to which this interaction is attached
+ :param color: The color to use for the zoom area bounding box
+ """
+
+ def __init__(self, plot, color):
+ super(ZoomAndSelect, self).__init__(plot)
+ self._zoom = Zoom(plot, color)
+ self._doZoom = False
+
+ @property
+ def color(self):
+ """Color of the zoom area"""
+ return self._zoom.color
+
+ def click(self, x, y, btn):
+ """Handle mouse click
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: True if click is catched by an item, False otherwise
+ """
+ eventDict = self._handleClick(x, y, btn)
+
+ if eventDict is not None:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ clickedEventDict = prepareMouseSignal('mouseClicked', btn,
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**clickedEventDict)
+
+ self.plot.notify(**eventDict)
+
+ else:
+ self._zoom.click(x, y, btn)
+
+ def beginDrag(self, x, y):
+ """Handle start drag and switching between zoom and item drag.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ """
+ self._doZoom = not super(ZoomAndSelect, self).beginDrag(x, y)
+ if self._doZoom:
+ self._zoom.beginDrag(x, y)
+
+ def drag(self, x, y):
+ """Handle drag, eventually forwarding to zoom.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ """
+ if self._doZoom:
+ return self._zoom.drag(x, y)
+ else:
+ return super(ZoomAndSelect, self).drag(x, y)
+
+ def endDrag(self, startPos, endPos):
+ """Handle end of drag, eventually forwarding to zoom.
+
+ :param startPos: (x, y) position at the beginning of the drag
+ :param endPos: (x, y) position at the end of the drag
+ """
+ if self._doZoom:
+ return self._zoom.endDrag(startPos, endPos)
+ else:
+ return super(ZoomAndSelect, self).endDrag(startPos, endPos)
+
+
+# Interaction mode control ####################################################
+
+class PlotInteraction(object):
+ """Proxy to currently use state machine for interaction.
+
+ This allows to switch interactive mode.
+
+ :param plot: The :class:`Plot` to apply interaction to
+ """
+
+ _DRAW_MODES = {
+ 'polygon': SelectPolygon,
+ 'rectangle': SelectRectangle,
+ 'line': SelectLine,
+ 'vline': SelectVLine,
+ 'hline': SelectHLine,
+ 'polylines': SelectFreeLine,
+ 'pencil': DrawFreeHand,
+ }
+
+ def __init__(self, plot):
+ self._plot = weakref.ref(plot) # Avoid cyclic-ref
+
+ self.zoomOnWheel = True
+ """True to enable zoom on wheel, False otherwise."""
+
+ # Default event handler
+ self._eventHandler = ItemsInteraction(plot)
+
+ def getInteractiveMode(self):
+ """Returns the current interactive mode as a dict.
+
+ The returned dict contains at least the key 'mode'.
+ Mode can be: 'draw', 'pan', 'select', 'zoom'.
+ It can also contains extra keys (e.g., 'color') specific to a mode
+ as provided to :meth:`setInteractiveMode`.
+ """
+ if isinstance(self._eventHandler, ZoomAndSelect):
+ return {'mode': 'zoom', 'color': self._eventHandler.color}
+
+ elif isinstance(self._eventHandler, Select):
+ result = self._eventHandler.parameters.copy()
+ result['mode'] = 'draw'
+ return result
+
+ elif isinstance(self._eventHandler, Pan):
+ return {'mode': 'pan'}
+
+ else:
+ return {'mode': 'select'}
+
+ def setInteractiveMode(self, mode, color='black',
+ shape='polygon', label=None, width=None):
+ """Switch the interactive mode.
+
+ :param str mode: The name of the interactive mode.
+ In 'draw', 'pan', 'select', 'zoom'.
+ :param color: Only for 'draw' and 'zoom' modes.
+ Color to use for drawing selection area. Default black.
+ If None, selection area is not drawn.
+ :type color: Color description: The name as a str or
+ a tuple of 4 floats or None.
+ :param str shape: Only for 'draw' mode. The kind of shape to draw.
+ In 'polygon', 'rectangle', 'line', 'vline', 'hline',
+ 'polylines'.
+ Default is 'polygon'.
+ :param str label: Only for 'draw' mode.
+ :param float width: Width of the pencil. Only for draw pencil mode.
+ """
+ assert mode in ('draw', 'pan', 'select', 'zoom')
+
+ plot = self._plot()
+ assert plot is not None
+
+ if color not in (None, 'video inverted'):
+ color = Colors.rgba(color)
+
+ if mode == 'draw':
+ assert shape in self._DRAW_MODES
+ eventHandlerClass = self._DRAW_MODES[shape]
+ parameters = {
+ 'shape': shape,
+ 'label': label,
+ 'color': color,
+ 'width': width,
+ }
+
+ self._eventHandler.cancel()
+ self._eventHandler = eventHandlerClass(plot, parameters)
+
+ elif mode == 'pan':
+ # Ignores color, shape and label
+ self._eventHandler.cancel()
+ self._eventHandler = Pan(plot)
+
+ elif mode == 'zoom':
+ # Ignores shape and label
+ self._eventHandler.cancel()
+ self._eventHandler = ZoomAndSelect(plot, color)
+
+ else: # Default mode: interaction with plot objects
+ # Ignores color, shape and label
+ self._eventHandler.cancel()
+ self._eventHandler = ItemsInteraction(plot)
+
+ def handleEvent(self, event, *args, **kwargs):
+ """Forward event to current interactive mode state machine."""
+ if not self.zoomOnWheel and event == 'wheel':
+ return # Discard wheel events
+ self._eventHandler.handleEvent(event, *args, **kwargs)
diff --git a/silx/gui/plot/PlotToolButtons.py b/silx/gui/plot/PlotToolButtons.py
new file mode 100644
index 0000000..8042391
--- /dev/null
+++ b/silx/gui/plot/PlotToolButtons.py
@@ -0,0 +1,280 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a set of QToolButton to use with :class:`.PlotWidget`.
+
+The following QToolButton are available:
+
+- :class:`AspectToolButton`
+- :class:`YAxisOriginToolButton`
+- :class:`ProfileToolButton`
+
+"""
+
+__authors__ = ["V. Valls", "H. Payno"]
+__license__ = "MIT"
+__date__ = "26/01/2017"
+
+
+import logging
+from .. import icons
+from .. import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PlotToolButton(qt.QToolButton):
+ """A QToolButton connected to a :class:`.PlotWidget`.
+ """
+
+ def __init__(self, parent=None, plot=None):
+ super(PlotToolButton, self).__init__(parent)
+ self._plot = None
+ if plot is not None:
+ self.setPlot(plot)
+
+ def plot(self):
+ """
+ Returns the plot connected to the widget.
+ """
+ return self._plot
+
+ def setPlot(self, plot):
+ """
+ Set the plot connected to the widget
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ """
+ if self._plot is plot:
+ return
+ if self._plot is not None:
+ self._disconnectPlot(self._plot)
+ self._plot = plot
+ if self._plot is not None:
+ self._connectPlot(self._plot)
+
+ def _connectPlot(self, plot):
+ """
+ Called when the plot is connected to the widget
+
+ :param plot: :class:`.PlotWidget` instance
+ """
+ pass
+
+ def _disconnectPlot(self, plot):
+ """
+ Called when the plot is disconnected from the widget
+
+ :param plot: :class:`.PlotWidget` instance
+ """
+ pass
+
+
+class AspectToolButton(PlotToolButton):
+
+ STATE = None
+ """Lazy loaded states used to feed AspectToolButton"""
+
+ def __init__(self, parent=None, plot=None):
+ if self.STATE is None:
+ self.STATE = {}
+ # dont keep ratio
+ self.STATE[False, "icon"] = icons.getQIcon('shape-ellipse-solid')
+ self.STATE[False, "state"] = "Aspect ratio is not kept"
+ self.STATE[False, "action"] = "Do no keep data aspect ratio"
+ # keep ratio
+ self.STATE[True, "icon"] = icons.getQIcon('shape-circle-solid')
+ self.STATE[True, "state"] = "Aspect ratio is kept"
+ self.STATE[True, "action"] = "Keep data aspect ratio"
+
+ super(AspectToolButton, self).__init__(parent=parent, plot=plot)
+
+ keepAction = self._createAction(True)
+ keepAction.triggered.connect(self.keepDataAspectRatio)
+ keepAction.setIconVisibleInMenu(True)
+
+ dontKeepAction = self._createAction(False)
+ dontKeepAction.triggered.connect(self.dontKeepDataAspectRatio)
+ dontKeepAction.setIconVisibleInMenu(True)
+
+ menu = qt.QMenu(self)
+ menu.addAction(keepAction)
+ menu.addAction(dontKeepAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ def _createAction(self, keepAspectRatio):
+ icon = self.STATE[keepAspectRatio, "icon"]
+ text = self.STATE[keepAspectRatio, "action"]
+ return qt.QAction(icon, text, self)
+
+ def _connectPlot(self, plot):
+ plot.sigSetKeepDataAspectRatio.connect(self._keepDataAspectRatioChanged)
+ self._keepDataAspectRatioChanged(plot.isKeepDataAspectRatio())
+
+ def _disconnectPlot(self, plot):
+ plot.sigSetKeepDataAspectRatio.disconnect(self._keepDataAspectRatioChanged)
+
+ def keepDataAspectRatio(self):
+ """Configure the plot to keep the aspect ratio"""
+ plot = self.plot()
+ if plot is not None:
+ # This will trigger _keepDataAspectRatioChanged
+ plot.setKeepDataAspectRatio(True)
+
+ def dontKeepDataAspectRatio(self):
+ """Configure the plot to not keep the aspect ratio"""
+ plot = self.plot()
+ if plot is not None:
+ # This will trigger _keepDataAspectRatioChanged
+ plot.setKeepDataAspectRatio(False)
+
+ def _keepDataAspectRatioChanged(self, aspectRatio):
+ """Handle Plot set keep aspect ratio signal"""
+ icon, toolTip = self.STATE[aspectRatio, "icon"], self.STATE[aspectRatio, "state"]
+ self.setIcon(icon)
+ self.setToolTip(toolTip)
+
+
+class YAxisOriginToolButton(PlotToolButton):
+
+ STATE = None
+ """Lazy loaded states used to feed YAxisOriginToolButton"""
+
+ def __init__(self, parent=None, plot=None):
+ if self.STATE is None:
+ self.STATE = {}
+ # is down
+ self.STATE[False, "icon"] = icons.getQIcon('plot-ydown')
+ self.STATE[False, "state"] = "Y-axis is oriented downward"
+ self.STATE[False, "action"] = "Orient Y-axis downward"
+ # keep ration
+ self.STATE[True, "icon"] = icons.getQIcon('plot-yup')
+ self.STATE[True, "state"] = "Y-axis is oriented upward"
+ self.STATE[True, "action"] = "Orient Y-axis upward"
+
+ super(YAxisOriginToolButton, self).__init__(parent=parent, plot=plot)
+
+ upwardAction = self._createAction(True)
+ upwardAction.triggered.connect(self.setYAxisUpward)
+ upwardAction.setIconVisibleInMenu(True)
+
+ downwardAction = self._createAction(False)
+ downwardAction.triggered.connect(self.setYAxisDownward)
+ downwardAction.setIconVisibleInMenu(True)
+
+ menu = qt.QMenu(self)
+ menu.addAction(upwardAction)
+ menu.addAction(downwardAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ def _createAction(self, isUpward):
+ icon = self.STATE[isUpward, "icon"]
+ text = self.STATE[isUpward, "action"]
+ return qt.QAction(icon, text, self)
+
+ def _connectPlot(self, plot):
+ plot.sigSetYAxisInverted.connect(self._yAxisInvertedChanged)
+ self._yAxisInvertedChanged(plot.isYAxisInverted())
+
+ def _disconnectPlot(self, plot):
+ plot.sigSetYAxisInverted.disconnect(self._yAxisInvertedChanged)
+
+ def setYAxisUpward(self):
+ """Configure the plot to use y-axis upward"""
+ plot = self.plot()
+ if plot is not None:
+ # This will trigger _yAxisInvertedChanged
+ plot.setYAxisInverted(False)
+
+ def setYAxisDownward(self):
+ """Configure the plot to use y-axis downward"""
+ plot = self.plot()
+ if plot is not None:
+ # This will trigger _yAxisInvertedChanged
+ plot.setYAxisInverted(True)
+
+ def _yAxisInvertedChanged(self, inverted):
+ """Handle Plot set y axis inverted signal"""
+ isUpward = not inverted
+ icon, toolTip = self.STATE[isUpward, "icon"], self.STATE[isUpward, "state"]
+ self.setIcon(icon)
+ self.setToolTip(toolTip)
+
+
+class ProfileToolButton(PlotToolButton):
+ """Button used in Profile3DToolbar to switch between 2D profile
+ and 1D profile."""
+ STATE = None
+ """Lazy loaded states used to feed ProfileToolButton"""
+
+ sigDimensionChanged = qt.Signal(int)
+
+ def __init__(self, parent=None, plot=None):
+ if self.STATE is None:
+ self.STATE = {
+ (1, "icon"): icons.getQIcon('profile1D'),
+ (1, "state"): "1D profile is computed on visible image",
+ (1, "action"): "1D profile on visible image",
+ (2, "icon"): icons.getQIcon('profile2D'),
+ (2, "state"): "2D profile is computed, one 1D profile for each image in the stack",
+ (2, "action"): "2D profile on image stack"}
+ # Compute 1D profile
+ # Compute 2D profile
+
+ super(ProfileToolButton, self).__init__(parent=parent, plot=plot)
+
+ profile1DAction = self._createAction(1)
+ profile1DAction.triggered.connect(self.computeProfileIn1D)
+ profile1DAction.setIconVisibleInMenu(True)
+
+ profile2DAction = self._createAction(2)
+ profile2DAction.triggered.connect(self.computeProfileIn2D)
+ profile2DAction.setIconVisibleInMenu(True)
+
+ menu = qt.QMenu(self)
+ menu.addAction(profile1DAction)
+ menu.addAction(profile2DAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+ menu.setTitle('Select profile dimension')
+
+ def _createAction(self, profileDimension):
+ icon = self.STATE[profileDimension, "icon"]
+ text = self.STATE[profileDimension, "action"]
+ return qt.QAction(icon, text, self)
+
+ def _profileDimensionChanged(self, profileDimension):
+ """Update icon in toolbar, emit number of dimensions for profile"""
+ self.setIcon(self.STATE[profileDimension, "icon"])
+ self.setToolTip(self.STATE[profileDimension, "state"])
+ self.sigDimensionChanged.emit(profileDimension)
+
+ def computeProfileIn1D(self):
+ self._profileDimensionChanged(1)
+
+ def computeProfileIn2D(self):
+ self._profileDimensionChanged(2)
diff --git a/silx/gui/plot/PlotTools.py b/silx/gui/plot/PlotTools.py
new file mode 100644
index 0000000..7158d0e
--- /dev/null
+++ b/silx/gui/plot/PlotTools.py
@@ -0,0 +1,313 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Set of widgets to associate with a :class:'PlotWidget'.
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/03/2017"
+
+
+import logging
+import numbers
+import traceback
+import weakref
+
+import numpy
+
+from .. import qt
+
+_logger = logging.getLogger(__name__)
+_logger.setLevel(logging.DEBUG)
+
+
+# PositionInfo ################################################################
+
+class PositionInfo(qt.QWidget):
+ """QWidget displaying coords converted from data coords of the mouse.
+
+ Provide this widget with a list of couple:
+
+ - A name to display before the data
+ - A function that takes (x, y) as arguments and returns something that
+ gets converted to a string.
+ If the result is a float it is converted with '%.7g' format.
+
+ To run the following sample code, a QApplication must be initialized.
+ First, create a PlotWindow and add a QToolBar where to place the
+ PositionInfo widget.
+
+ >>> from silx.gui.plot import PlotWindow
+ >>> from silx.gui import qt
+
+ >>> plot = PlotWindow() # Create a PlotWindow to add the widget to
+ >>> toolBar = qt.QToolBar() # Create a toolbar to place the widget in
+ >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar) # Add it to plot
+
+ Then, create the PositionInfo widget and add it to the toolbar.
+ The PositionInfo widget is created with a list of converters, here
+ to display polar coordinates of the mouse position.
+
+ >>> import numpy
+ >>> from silx.gui.plot.PlotTools import PositionInfo
+
+ >>> position = PositionInfo(plot=plot, converters=[
+ ... ('Radius', lambda x, y: numpy.sqrt(x*x + y*y)),
+ ... ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))])
+ >>> toolBar.addWidget(position) # Add the widget to the toolbar
+ <...>
+ >>> plot.show() # To display the PlotWindow with the position widget
+
+ :param plot: The PlotWidget this widget is displaying data coords from.
+ :param converters: List of name to display and conversion function from
+ (x, y) in data coords to displayed value.
+ If None, the default, it displays X and Y.
+ :type converters: Iterable of 2-tuple (str, function)
+ :param parent: Parent widget
+ """
+
+ def __init__(self, parent=None, plot=None, converters=None):
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+
+ super(PositionInfo, self).__init__(parent)
+
+ if converters is None:
+ converters = (('X', lambda x, y: x), ('Y', lambda x, y: y))
+
+ self.autoSnapToActiveCurve = False
+ """Toggle snapping use position to active curve.
+
+ - True to snap used coordinates to the active curve if the active curve
+ is displayed with symbols and mouse is close enough.
+ If the mouse is not close to a point of the curve, values are
+ displayed in red.
+ - False (the default) to always use mouse coordinates.
+
+ """
+
+ self._fields = [] # To store (QLineEdit, name, function (x, y)->v)
+
+ # Create a new layout with new widgets
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ # layout.setSpacing(0)
+
+ # Create all QLabel and store them with the corresponding converter
+ for name, func in converters:
+ layout.addWidget(qt.QLabel('<b>' + name + ':</b>'))
+
+ contentWidget = qt.QLabel()
+ contentWidget.setText('------')
+ contentWidget.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
+ contentWidget.setFixedWidth(
+ contentWidget.fontMetrics().width('##############'))
+ layout.addWidget(contentWidget)
+ self._fields.append((contentWidget, name, func))
+
+ layout.addStretch(1)
+ self.setLayout(layout)
+
+ # Connect to Plot events
+ plot.sigPlotSignal.connect(self._plotEvent)
+
+ @property
+ def plot(self):
+ """The :class:`.PlotWindow` this widget is attached to."""
+ return self._plotRef()
+
+ def getConverters(self):
+ """Return the list of converters as 2-tuple (name, function)."""
+ return [(name, func) for _label, name, func in self._fields]
+
+ def _plotEvent(self, event):
+ """Handle events from the Plot.
+
+ :param dict event: Plot event
+ """
+ if event['event'] == 'mouseMoved':
+ x, y = event['x'], event['y'] # Position in data
+ styleSheet = "color: rgb(0, 0, 0);" # Default style
+
+ if self.autoSnapToActiveCurve and self.plot.getGraphCursor():
+ # Check if near active curve with symbols.
+
+ styleSheet = "color: rgb(255, 0, 0);" # Style far from curve
+
+ activeCurve = self.plot.getActiveCurve()
+ if activeCurve:
+ xData = activeCurve.getXData(copy=False)
+ yData = activeCurve.getYData(copy=False)
+ if activeCurve.getSymbol(): # Only handled if symbols on curve
+ closestIndex = numpy.argmin(
+ pow(xData - x, 2) + pow(yData - y, 2))
+
+ xClosest = xData[closestIndex]
+ yClosest = yData[closestIndex]
+
+ closestInPixels = self.plot.dataToPixel(
+ xClosest, yClosest, axis=activeCurve.getYAxis())
+ if closestInPixels is not None:
+ xPixel, yPixel = event['xpixel'], event['ypixel']
+
+ if (abs(closestInPixels[0] - xPixel) < 5 and
+ abs(closestInPixels[1] - yPixel) < 5):
+ # Update label style sheet
+ styleSheet = "color: rgb(0, 0, 0);"
+
+ # if close enough, wrap to data point coords
+ x, y = xClosest, yClosest
+
+ for label, name, func in self._fields:
+ label.setStyleSheet(styleSheet)
+
+ try:
+ value = func(x, y)
+ except:
+ label.setText('Error')
+ _logger.error(
+ "Error while converting coordinates (%f, %f)"
+ "with converter '%s'" % (x, y, name))
+ _logger.error(traceback.format_exc())
+ else:
+ if isinstance(value, numbers.Real):
+ value = '%.7g' % value # Use this for floats and int
+ else:
+ value = str(value) # Fallback for other types
+ label.setText(value)
+
+
+# LimitsToolBar ##############################################################
+
+class LimitsToolBar(qt.QToolBar):
+ """QToolBar displaying and controlling the limits of a :class:`PlotWidget`.
+
+ To run the following sample code, a QApplication must be initialized.
+ First, create a PlotWindow:
+
+ >>> from silx.gui.plot import PlotWindow
+ >>> plot = PlotWindow() # Create a PlotWindow to add the toolbar to
+
+ Then, create the LimitsToolBar and add it to the PlotWindow.
+
+ >>> from silx.gui import qt
+ >>> from silx.gui.plot.PlotTools import LimitsToolBar
+
+ >>> toolbar = LimitsToolBar(plot=plot) # Create the toolbar
+ >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolbar) # Add it to the plot
+ >>> plot.show() # To display the PlotWindow with the limits toolbar
+
+ :param parent: See :class:`QToolBar`.
+ :param plot: :class:`PlotWidget` instance on which to operate.
+ :param str title: See :class:`QToolBar`.
+ """
+
+ class _FloatEdit(qt.QLineEdit):
+ """Field to edit a float value."""
+ def __init__(self, value=None, *args, **kwargs):
+ qt.QLineEdit.__init__(self, *args, **kwargs)
+ self.setValidator(qt.QDoubleValidator())
+ self.setFixedWidth(100)
+ self.setAlignment(qt.Qt.AlignLeft)
+ if value is not None:
+ self.setValue(value)
+
+ def value(self):
+ return float(self.text())
+
+ def setValue(self, value):
+ self.setText('%g' % value)
+
+ def __init__(self, parent=None, plot=None, title='Limits'):
+ super(LimitsToolBar, self).__init__(title, parent)
+ assert plot is not None
+ self._plot = plot
+ self._plot.sigPlotSignal.connect(self._plotWidgetSlot)
+
+ self._initWidgets()
+
+ @property
+ def plot(self):
+ """The :class:`PlotWidget` the toolbar is attached to."""
+ return self._plot
+
+ def _initWidgets(self):
+ """Create and init Toolbar widgets."""
+ xMin, xMax = self.plot.getGraphXLimits()
+ yMin, yMax = self.plot.getGraphYLimits()
+
+ self.addWidget(qt.QLabel('Limits: '))
+ self.addWidget(qt.QLabel(' X: '))
+ self._xMinFloatEdit = self._FloatEdit(xMin)
+ self._xMinFloatEdit.editingFinished[()].connect(
+ self._xFloatEditChanged)
+ self.addWidget(self._xMinFloatEdit)
+
+ self._xMaxFloatEdit = self._FloatEdit(xMax)
+ self._xMaxFloatEdit.editingFinished[()].connect(
+ self._xFloatEditChanged)
+ self.addWidget(self._xMaxFloatEdit)
+
+ self.addWidget(qt.QLabel(' Y: '))
+ self._yMinFloatEdit = self._FloatEdit(yMin)
+ self._yMinFloatEdit.editingFinished[()].connect(
+ self._yFloatEditChanged)
+ self.addWidget(self._yMinFloatEdit)
+
+ self._yMaxFloatEdit = self._FloatEdit(yMax)
+ self._yMaxFloatEdit.editingFinished[()].connect(
+ self._yFloatEditChanged)
+ self.addWidget(self._yMaxFloatEdit)
+
+ def _plotWidgetSlot(self, event):
+ """Listen to :class:`PlotWidget` events."""
+ if event['event'] not in ('limitsChanged',):
+ return
+
+ xMin, xMax = self.plot.getGraphXLimits()
+ yMin, yMax = self.plot.getGraphYLimits()
+
+ self._xMinFloatEdit.setValue(xMin)
+ self._xMaxFloatEdit.setValue(xMax)
+ self._yMinFloatEdit.setValue(yMin)
+ self._yMaxFloatEdit.setValue(yMax)
+
+ def _xFloatEditChanged(self):
+ """Handle X limits changed from the GUI."""
+ xMin, xMax = self._xMinFloatEdit.value(), self._xMaxFloatEdit.value()
+ if xMax < xMin:
+ xMin, xMax = xMax, xMin
+
+ self.plot.setGraphXLimits(xMin, xMax)
+
+ def _yFloatEditChanged(self):
+ """Handle Y limits changed from the GUI."""
+ yMin, yMax = self._yMinFloatEdit.value(), self._yMaxFloatEdit.value()
+ if yMax < yMin:
+ yMin, yMax = yMax, yMin
+
+ self.plot.setGraphYLimits(yMin, yMax)
diff --git a/silx/gui/plot/PlotWidget.py b/silx/gui/plot/PlotWidget.py
new file mode 100644
index 0000000..5666d56
--- /dev/null
+++ b/silx/gui/plot/PlotWidget.py
@@ -0,0 +1,267 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Qt widget providing Plot API for 1D and 2D data.
+
+This provides the plot API of :class:`silx.gui.plot.Plot.Plot` as a
+Qt widget.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "22/02/2016"
+
+
+import logging
+
+from . import Plot
+
+from .. import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PlotWidget(qt.QMainWindow, Plot.Plot):
+ """Qt Widget providing a 1D/2D plot.
+
+ This widget is a QMainWindow.
+ It provides Qt signals for the Plot and add supports for panning
+ with arrow keys.
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.Plot` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ sigPlotSignal = qt.Signal(object)
+ """Signal for all events of the plot.
+
+ The signal information is provided as a dict.
+ See :class:`.Plot` for documentation of the content of the dict.
+ """
+
+ sigSetYAxisInverted = qt.Signal(bool)
+ """Signal emitted when Y axis orientation has changed"""
+
+ sigSetXAxisLogarithmic = qt.Signal(bool)
+ """Signal emitted when X axis scale has changed"""
+
+ sigSetYAxisLogarithmic = qt.Signal(bool)
+ """Signal emitted when Y axis scale has changed"""
+
+ sigSetXAxisAutoScale = qt.Signal(bool)
+ """Signal emitted when X axis autoscale has changed"""
+
+ sigSetYAxisAutoScale = qt.Signal(bool)
+ """Signal emitted when Y axis autoscale has changed"""
+
+ sigSetKeepDataAspectRatio = qt.Signal(bool)
+ """Signal emitted when plot keep aspect ratio has changed"""
+
+ sigSetGraphGrid = qt.Signal(str)
+ """Signal emitted when plot grid has changed"""
+
+ sigSetGraphCursor = qt.Signal(bool)
+ """Signal emitted when plot crosshair cursor has changed"""
+
+ sigSetPanWithArrowKeys = qt.Signal(bool)
+ """Signal emitted when pan with arrow keys has changed"""
+
+ sigContentChanged = qt.Signal(str, str, str)
+ """Signal emitted when the content of the plot is changed.
+
+ It provides 3 informations:
+
+ - action: The change of the plot: 'add' or 'remove'
+ - kind: The kind of primitive changed:
+ 'curve', 'image', 'scatter', 'histogram', 'item' or 'marker'
+ - legend: The legend of the primitive changed.
+ """
+
+ sigActiveCurveChanged = qt.Signal(object, object)
+ """Signal emitted when the active curve has changed.
+
+ It provides 2 informations:
+
+ - previous: The legend of the previous active curve or None
+ - legend: The legend of the new active curve or None if no curve is active
+ """
+
+ sigActiveImageChanged = qt.Signal(object, object)
+ """Signal emitted when the active image has changed.
+
+ It provides 2 informations:
+
+ - previous: The legend of the previous active image or None
+ - legend: The legend of the new active image or None if no image is active
+ """
+
+ sigActiveScatterChanged = qt.Signal(object, object)
+ """Signal emitted when the active Scatter has changed.
+
+ It provides following information:
+
+ - previous: The legend of the previous active scatter or None
+ - legend: The legend of the new active image or None if no image is active
+ """
+
+ sigInteractiveModeChanged = qt.Signal(object)
+ """Signal emitted when the interactive mode has changed
+
+ It provides the source as passed to :meth:`setInteractiveMode`.
+ """
+
+ def __init__(self, parent=None, backend=None,
+ legends=False, callback=None, **kw):
+
+ if kw:
+ _logger.warning(
+ 'deprecated: __init__ extra arguments: %s', str(kw))
+ if legends:
+ _logger.warning('deprecated: __init__ legend argument')
+ if callback:
+ _logger.warning('deprecated: __init__ callback argument')
+
+ self._panWithArrowKeys = True
+
+ qt.QMainWindow.__init__(self, parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+ else:
+ self.setWindowTitle('PlotWidget')
+
+ Plot.Plot.__init__(self, parent, backend=backend)
+
+ widget = self.getWidgetHandle()
+ if widget is not None:
+ self.setCentralWidget(widget)
+ else:
+ _logger.warning("Plot backend does not support widget")
+
+ self.setFocusPolicy(qt.Qt.StrongFocus)
+ self.setFocus(qt.Qt.OtherFocusReason)
+
+ def notify(self, event, **kwargs):
+ """Override :meth:`Plot.notify` to send Qt signals."""
+ eventDict = kwargs.copy()
+ eventDict['event'] = event
+ self.sigPlotSignal.emit(eventDict)
+
+ if event == 'setYAxisInverted':
+ self.sigSetYAxisInverted.emit(kwargs['state'])
+ elif event == 'setXAxisLogarithmic':
+ self.sigSetXAxisLogarithmic.emit(kwargs['state'])
+ elif event == 'setYAxisLogarithmic':
+ self.sigSetYAxisLogarithmic.emit(kwargs['state'])
+ elif event == 'setXAxisAutoScale':
+ self.sigSetXAxisAutoScale.emit(kwargs['state'])
+ elif event == 'setYAxisAutoScale':
+ self.sigSetYAxisAutoScale.emit(kwargs['state'])
+ elif event == 'setKeepDataAspectRatio':
+ self.sigSetKeepDataAspectRatio.emit(kwargs['state'])
+ elif event == 'setGraphGrid':
+ self.sigSetGraphGrid.emit(kwargs['which'])
+ elif event == 'setGraphCursor':
+ self.sigSetGraphCursor.emit(kwargs['state'])
+ elif event == 'contentChanged':
+ self.sigContentChanged.emit(
+ kwargs['action'], kwargs['kind'], kwargs['legend'])
+ elif event == 'activeCurveChanged':
+ self.sigActiveCurveChanged.emit(
+ kwargs['previous'], kwargs['legend'])
+ elif event == 'activeImageChanged':
+ self.sigActiveImageChanged.emit(
+ kwargs['previous'], kwargs['legend'])
+ elif event == 'activeScatterChanged':
+ self.sigActiveScatterChanged.emit(
+ kwargs['previous'], kwargs['legend'])
+ elif event == 'interactiveModeChanged':
+ self.sigInteractiveModeChanged.emit(kwargs['source'])
+ Plot.Plot.notify(self, event, **kwargs)
+
+ # Panning with arrow keys
+
+ def isPanWithArrowKeys(self):
+ """Returns whether or not panning the graph with arrow keys is enable.
+
+ See :meth:`setPanWithArrowKeys`.
+ """
+ return self._panWithArrowKeys
+
+ def setPanWithArrowKeys(self, pan=False):
+ """Enable/Disable panning the graph with arrow keys.
+
+ This grabs the keyboard.
+
+ :param bool pan: True to enable panning, False to disable.
+ """
+ pan = bool(pan)
+ panHasChanged = self._panWithArrowKeys != pan
+
+ self._panWithArrowKeys = pan
+ if not self._panWithArrowKeys:
+ self.setFocusPolicy(qt.Qt.NoFocus)
+ else:
+ self.setFocusPolicy(qt.Qt.StrongFocus)
+ self.setFocus(qt.Qt.OtherFocusReason)
+
+ if panHasChanged:
+ self.sigSetPanWithArrowKeys.emit(pan)
+
+ # Dict to convert Qt arrow key code to direction str.
+ _ARROWS_TO_PAN_DIRECTION = {
+ qt.Qt.Key_Left: 'left',
+ qt.Qt.Key_Right: 'right',
+ qt.Qt.Key_Up: 'up',
+ qt.Qt.Key_Down: 'down'
+ }
+
+ def keyPressEvent(self, event):
+ """Key event handler handling panning on arrow keys.
+
+ Overrides base class implementation.
+ """
+ key = event.key()
+ if self._panWithArrowKeys and key in self._ARROWS_TO_PAN_DIRECTION:
+ self.pan(self._ARROWS_TO_PAN_DIRECTION[key], factor=0.1)
+
+ # Send a mouse move event to the plot widget to take into account
+ # that even if mouse didn't move on the screen, it moved relative
+ # to the plotted data.
+ qapp = qt.QApplication.instance()
+ event = qt.QMouseEvent(
+ qt.QEvent.MouseMove,
+ self.getWidgetHandle().mapFromGlobal(qt.QCursor.pos()),
+ qt.Qt.NoButton,
+ qapp.mouseButtons(),
+ qapp.keyboardModifiers())
+ qapp.sendEvent(self.getWidgetHandle(), event)
+
+ else:
+ # Only call base class implementation when key is not handled.
+ # See QWidget.keyPressEvent for details.
+ super(PlotWidget, self).keyPressEvent(event)
diff --git a/silx/gui/plot/PlotWindow.py b/silx/gui/plot/PlotWindow.py
new file mode 100644
index 0000000..ae25cfd
--- /dev/null
+++ b/silx/gui/plot/PlotWindow.py
@@ -0,0 +1,766 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A :class:`.PlotWidget` with additional toolbars.
+
+The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`.
+It provides the plot API fully defined in :class:`.Plot`.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "27/04/2017"
+
+import collections
+import logging
+
+from silx.utils.decorators import deprecated
+
+from . import PlotWidget
+from . import PlotActions
+from . import PlotToolButtons
+from .PlotTools import PositionInfo
+from .Profile import ProfileToolBar
+from .LegendSelector import LegendsDockWidget
+from .CurvesROIWidget import CurvesROIDockWidget
+from .MaskToolsWidget import MaskToolsDockWidget
+try:
+ from ..console import IPythonDockWidget
+except ImportError:
+ IPythonDockWidget = None
+
+from .. import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PlotWindow(PlotWidget):
+ """Qt Widget providing a 1D/2D plot area and additional tools.
+
+ This widgets inherits from :class:`.PlotWidget` and provides its plot API.
+
+ Initialiser parameters:
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.Plot` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ :param bool resetzoom: Toggle visibility of reset zoom action.
+ :param bool autoScale: Toggle visibility of axes autoscale actions.
+ :param bool logScale: Toggle visibility of axes log scale actions.
+ :param bool grid: Toggle visibility of grid mode action.
+ :param bool curveStyle: Toggle visibility of curve style action.
+ :param bool colormap: Toggle visibility of colormap action.
+ :param bool aspectRatio: Toggle visibility of aspect ratio button.
+ :param bool yInverted: Toggle visibility of Y axis direction button.
+ :param bool copy: Toggle visibility of copy action.
+ :param bool save: Toggle visibility of save action.
+ :param bool print_: Toggle visibility of print action.
+ :param bool control: True to display an Options button with a sub-menu
+ to show legends, toggle crosshair and pan with arrows.
+ (Default: False)
+ :param position: True to display widget with (x, y) mouse position
+ (Default: False).
+ It also supports a list of (name, funct(x, y)->value)
+ to customize the displayed values.
+ See :class:`silx.gui.plot.PlotTools.PositionInfo`.
+ :param bool roi: Toggle visibilty of ROI action.
+ :param bool mask: Toggle visibilty of mask action.
+ :param bool fit: Toggle visibilty of fit action.
+ """
+
+ def __init__(self, parent=None, backend=None,
+ resetzoom=True, autoScale=True, logScale=True, grid=True,
+ curveStyle=True, colormap=True,
+ aspectRatio=True, yInverted=True,
+ copy=True, save=True, print_=True,
+ control=False, position=False,
+ roi=True, mask=True, fit=False):
+ super(PlotWindow, self).__init__(parent=parent, backend=backend)
+ if parent is None:
+ self.setWindowTitle('PlotWindow')
+
+ self._dockWidgets = []
+
+ # lazy loaded dock widgets
+ self._legendsDockWidget = None
+ self._curvesROIDockWidget = None
+ self._maskToolsDockWidget = None
+
+ # Init actions
+ self.group = qt.QActionGroup(self)
+ self.group.setExclusive(False)
+
+ self.resetZoomAction = self.group.addAction(PlotActions.ResetZoomAction(self))
+ self.resetZoomAction.setVisible(resetzoom)
+ self.addAction(self.resetZoomAction)
+
+ self.zoomInAction = PlotActions.ZoomInAction(self)
+ self.addAction(self.zoomInAction)
+
+ self.zoomOutAction = PlotActions.ZoomOutAction(self)
+ self.addAction(self.zoomOutAction)
+
+ self.xAxisAutoScaleAction = self.group.addAction(
+ PlotActions.XAxisAutoScaleAction(self))
+ self.xAxisAutoScaleAction.setVisible(autoScale)
+ self.addAction(self.xAxisAutoScaleAction)
+
+ self.yAxisAutoScaleAction = self.group.addAction(
+ PlotActions.YAxisAutoScaleAction(self))
+ self.yAxisAutoScaleAction.setVisible(autoScale)
+ self.addAction(self.yAxisAutoScaleAction)
+
+ self.xAxisLogarithmicAction = self.group.addAction(
+ PlotActions.XAxisLogarithmicAction(self))
+ self.xAxisLogarithmicAction.setVisible(logScale)
+ self.addAction(self.xAxisLogarithmicAction)
+
+ self.yAxisLogarithmicAction = self.group.addAction(
+ PlotActions.YAxisLogarithmicAction(self))
+ self.yAxisLogarithmicAction.setVisible(logScale)
+ self.addAction(self.yAxisLogarithmicAction)
+
+ self.gridAction = self.group.addAction(
+ PlotActions.GridAction(self, gridMode='both'))
+ self.gridAction.setVisible(grid)
+ self.addAction(self.gridAction)
+
+ self.curveStyleAction = self.group.addAction(PlotActions.CurveStyleAction(self))
+ self.curveStyleAction.setVisible(curveStyle)
+ self.addAction(self.curveStyleAction)
+
+ self.colormapAction = self.group.addAction(PlotActions.ColormapAction(self))
+ self.colormapAction.setVisible(colormap)
+ self.addAction(self.colormapAction)
+
+ self.keepDataAspectRatioButton = PlotToolButtons.AspectToolButton(
+ parent=self, plot=self)
+ self.keepDataAspectRatioButton.setVisible(aspectRatio)
+
+ self.yAxisInvertedButton = PlotToolButtons.YAxisOriginToolButton(
+ parent=self, plot=self)
+ self.yAxisInvertedButton.setVisible(yInverted)
+
+ self.group.addAction(self.getRoiAction())
+ self.getRoiAction().setVisible(roi)
+
+ self.group.addAction(self.getMaskAction())
+ self.getMaskAction().setVisible(mask)
+
+ self._intensityHistoAction = self.group.addAction(
+ PlotActions.PixelIntensitiesHistoAction(self))
+ self._intensityHistoAction.setVisible(False)
+
+ self._medianFilter2DAction = self.group.addAction(
+ PlotActions.MedianFilter2DAction(self))
+ self._medianFilter2DAction.setVisible(False)
+
+ self._medianFilter1DAction = self.group.addAction(
+ PlotActions.MedianFilter1DAction(self))
+ self._medianFilter1DAction.setVisible(False)
+
+ self._separator = qt.QAction('separator', self)
+ self._separator.setSeparator(True)
+ self.group.addAction(self._separator)
+
+ self.copyAction = self.group.addAction(PlotActions.CopyAction(self))
+ self.copyAction.setVisible(copy)
+ self.addAction(self.copyAction)
+
+ self.saveAction = self.group.addAction(PlotActions.SaveAction(self))
+ self.saveAction.setVisible(save)
+ self.addAction(self.saveAction)
+
+ self.printAction = self.group.addAction(PlotActions.PrintAction(self))
+ self.printAction.setVisible(print_)
+ self.addAction(self.printAction)
+
+ self.fitAction = self.group.addAction(PlotActions.FitAction(self))
+ self.fitAction.setVisible(fit)
+ self.addAction(self.fitAction)
+
+ # lazy loaded actions needed by the controlButton menu
+ self._consoleAction = None
+ self._panWithArrowKeysAction = None
+ self._crosshairAction = None
+
+ if control or position:
+ hbox = qt.QHBoxLayout()
+ hbox.setContentsMargins(0, 0, 0, 0)
+
+ if control:
+ self.controlButton = qt.QToolButton()
+ self.controlButton.setText("Options")
+ self.controlButton.setToolButtonStyle(qt.Qt.ToolButtonTextBesideIcon)
+ self.controlButton.setAutoRaise(True)
+ self.controlButton.setPopupMode(qt.QToolButton.InstantPopup)
+ menu = qt.QMenu(self)
+ menu.aboutToShow.connect(self._customControlButtonMenu)
+ self.controlButton.setMenu(menu)
+
+ hbox.addWidget(self.controlButton)
+
+ if position: # Add PositionInfo widget to the bottom of the plot
+ if isinstance(position, collections.Iterable):
+ # Use position as a set of converters
+ converters = position
+ else:
+ converters = None
+ self.positionWidget = PositionInfo(
+ plot=self, converters=converters)
+ self.positionWidget.autoSnapToActiveCurve = True
+
+ hbox.addWidget(self.positionWidget)
+
+ hbox.addStretch(1)
+ bottomBar = qt.QWidget()
+ bottomBar.setLayout(hbox)
+
+ layout = qt.QVBoxLayout()
+ layout.setSpacing(0)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self.getWidgetHandle())
+ layout.addWidget(bottomBar)
+ layout.setStretch(0, 1)
+
+ centralWidget = qt.QWidget()
+ centralWidget.setLayout(layout)
+ self.setCentralWidget(centralWidget)
+
+ # Creating the toolbar also create actions for toolbuttons
+ self._toolbar = self._createToolBar(title='Plot', parent=None)
+ self.addToolBar(self._toolbar)
+
+ def getSelectionMask(self):
+ """Return the current mask handled by :attr:`maskToolsDockWidget`.
+
+ :return: The array of the mask with dimension of the 'active' image.
+ If there is no active image, an empty array is returned.
+ :rtype: 2D numpy.ndarray of uint8
+ """
+ return self.getMaskToolsDockWidget().getSelectionMask()
+
+ def setSelectionMask(self, mask):
+ """Set the mask handled by :attr:`maskToolsDockWidget`.
+
+ If the provided mask has not the same dimension as the 'active'
+ image, it will by cropped or padded.
+
+ :param mask: The array to use for the mask.
+ :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous.
+ Array of other types are converted.
+ :return: True if success, False if failed
+ """
+ return bool(self.getMaskToolsDockWidget().setSelectionMask(mask))
+
+ def _toggleConsoleVisibility(self, is_checked=False):
+ """Create IPythonDockWidget if needed,
+ show it or hide it."""
+ # create widget if needed (first call)
+ if not hasattr(self, '_consoleDockWidget'):
+ available_vars = {"plt": self}
+ banner = "The variable 'plt' is available. Use the 'whos' "
+ banner += "and 'help(plt)' commands for more information.\n\n"
+ self._consoleDockWidget = IPythonDockWidget(
+ available_vars=available_vars,
+ custom_banner=banner,
+ parent=self)
+ self.addTabbedDockWidget(self._consoleDockWidget)
+ self._consoleDockWidget.visibilityChanged.connect(
+ self.getConsoleAction().setChecked)
+
+ self._consoleDockWidget.setVisible(is_checked)
+
+ def _createToolBar(self, title, parent):
+ """Create a QToolBar from the QAction of the PlotWindow.
+
+ :param str title: The title of the QMenu
+ :param qt.QWidget parent: See :class:`QToolBar`
+ """
+ toolbar = qt.QToolBar(title, parent)
+
+ # Order widgets with actions
+ objects = self.group.actions()
+
+ # Add push buttons to list
+ index = objects.index(self.colormapAction)
+ objects.insert(index + 1, self.keepDataAspectRatioButton)
+ objects.insert(index + 2, self.yAxisInvertedButton)
+
+ for obj in objects:
+ if isinstance(obj, qt.QAction):
+ toolbar.addAction(obj)
+ else:
+ # Add action for toolbutton in order to allow changing
+ # visibility (see doc QToolBar.addWidget doc)
+ if obj is self.keepDataAspectRatioButton:
+ self.keepDataAspectRatioAction = toolbar.addWidget(obj)
+ elif obj is self.yAxisInvertedButton:
+ self.yAxisInvertedAction = toolbar.addWidget(obj)
+ else:
+ raise RuntimeError()
+ return toolbar
+
+ def toolBar(self):
+ """Return a QToolBar from the QAction of the PlotWindow.
+ """
+ return self._toolbar
+
+ def menu(self, title='Plot', parent=None):
+ """Return a QMenu from the QAction of the PlotWindow.
+
+ :param str title: The title of the QMenu
+ :param parent: See :class:`QMenu`
+ """
+ menu = qt.QMenu(title, parent)
+ for action in self.group.actions():
+ menu.addAction(action)
+ return menu
+
+ def _customControlButtonMenu(self):
+ """Display Options button sub-menu."""
+ controlMenu = self.controlButton.menu()
+ controlMenu.clear()
+ controlMenu.addAction(self.getLegendsDockWidget().toggleViewAction())
+ controlMenu.addAction(self.getRoiAction())
+ controlMenu.addAction(self.getMaskAction())
+ controlMenu.addAction(self.getConsoleAction())
+
+ controlMenu.addSeparator()
+ controlMenu.addAction(self.getCrosshairAction())
+ controlMenu.addAction(self.getPanWithArrowKeysAction())
+
+ def addTabbedDockWidget(self, dock_widget):
+ """Add a dock widget as a new tab if there are already dock widgets
+ in the plot. When the first tab is added, the area is chosen
+ depending on the plot geometry:
+ it the window is much wider than it is high, the right dock area
+ is used, else the bottom dock area is used.
+
+ :param dock_widget: Instance of :class:`QDockWidget` to be added.
+ """
+ if dock_widget not in self._dockWidgets:
+ self._dockWidgets.append(dock_widget)
+ if len(self._dockWidgets) == 1:
+ # The first created dock widget must be added to a Widget area
+ width = self.centralWidget().width()
+ height = self.centralWidget().height()
+ if width > (2.0 * height) and width > 1000:
+ area = qt.Qt.RightDockWidgetArea
+ else:
+ area = qt.Qt.BottomDockWidgetArea
+ self.addDockWidget(area, dock_widget)
+ else:
+ # Other dock widgets are added as tabs to the same widget area
+ self.tabifyDockWidget(self._dockWidgets[0],
+ dock_widget)
+
+ # getters for dock widgets
+ @property
+ @deprecated(replacement="getLegendsDockWidget()", since_version="0.4.0")
+ def legendsDockWidget(self):
+ return self.getLegendsDockWidget()
+
+ def getLegendsDockWidget(self):
+ """DockWidget with Legend panel"""
+ if self._legendsDockWidget is None:
+ self._legendsDockWidget = LegendsDockWidget(plot=self)
+ self._legendsDockWidget.hide()
+ self.addTabbedDockWidget(self._legendsDockWidget)
+ return self._legendsDockWidget
+
+ @property
+ @deprecated(replacement="getCurvesRoiDockWidget()", since_version="0.4.0")
+ def curvesROIDockWidget(self):
+ return self.getCurvesRoiDockWidget()
+
+ def getCurvesRoiDockWidget(self):
+ """DockWidget with curves' ROI panel (lazy-loaded).
+
+ The widget returned is a :class:`CurvesROIDockWidget`.
+ Its central widget is a :class:`CurvesROIWidget`
+ accessible as :attr:`CurvesROIDockWidget.roiWidget`.
+
+ :class:`silx.gui.plot.CurvesROIWidget.CurvesROIWidget` offers a getter
+ and a setter for the ROI data:
+
+ - :meth:`CurvesROIWidget.getRois`
+ - :meth:`CurvesROIWidget.setRois`
+ """
+ if self._curvesROIDockWidget is None:
+ self._curvesROIDockWidget = CurvesROIDockWidget(
+ plot=self, name='Regions Of Interest')
+ self._curvesROIDockWidget.hide()
+ self.addTabbedDockWidget(self._curvesROIDockWidget)
+ return self._curvesROIDockWidget
+
+ @property
+ @deprecated(replacement="getMaskToolsDockWidget()", since_version="0.4.0")
+ def maskToolsDockWidget(self):
+ return self.getMaskToolsDockWidget()
+
+ def getMaskToolsDockWidget(self):
+ """DockWidget with image mask panel (lazy-loaded)."""
+ if self._maskToolsDockWidget is None:
+ self._maskToolsDockWidget = MaskToolsDockWidget(
+ plot=self, name='Mask')
+ self._maskToolsDockWidget.hide()
+ self.addTabbedDockWidget(self._maskToolsDockWidget)
+ return self._maskToolsDockWidget
+
+ # getters for actions
+ @property
+ @deprecated(replacement="getConsoleAction()", since_version="0.4.0")
+ def consoleAction(self):
+ return self.getConsoleAction()
+
+ def getConsoleAction(self):
+ """QAction handling the IPython console activation.
+
+ By default, it is connected to a method that initializes the
+ console widget the first time the user clicks the "Console" menu
+ button. The following clicks, after initialization is done,
+ will toggle the visibility of the console widget.
+
+ :rtype: QAction
+ """
+ if self._consoleAction is None:
+ self._consoleAction = qt.QAction('Console', self)
+ self._consoleAction.setCheckable(True)
+ if IPythonDockWidget is not None:
+ self._consoleAction.toggled.connect(self._toggleConsoleVisibility)
+ else:
+ self._consoleAction.setEnabled(False)
+ return self._consoleAction
+
+ @property
+ @deprecated(replacement="getCrosshairAction()", since_version="0.4.0")
+ def crosshairAction(self):
+ return self.getCrosshairAction()
+
+ def getCrosshairAction(self):
+ """Action toggling crosshair cursor mode.
+
+ :rtype: PlotActions.PlotAction
+ """
+ if self._crosshairAction is None:
+ self._crosshairAction = PlotActions.CrosshairAction(self, color='red')
+ return self._crosshairAction
+
+ @property
+ @deprecated(replacement="getMaskAction()", since_version="0.4.0")
+ def maskAction(self):
+ return self.getMaskAction()
+
+ def getMaskAction(self):
+ """QAction toggling image mask dock widget
+
+ :rtype: QAction
+ """
+ return self.getMaskToolsDockWidget().toggleViewAction()
+
+ @property
+ @deprecated(replacement="getPanWithArrowKeysAction()",
+ since_version="0.4.0")
+ def panWithArrowKeysAction(self):
+ return self.getPanWithArrowKeysAction()
+
+ def getPanWithArrowKeysAction(self):
+ """Action toggling pan with arrow keys.
+
+ :rtype: PlotActions.PlotAction
+ """
+ if self._panWithArrowKeysAction is None:
+ self._panWithArrowKeysAction = PlotActions.PanWithArrowKeysAction(self)
+ return self._panWithArrowKeysAction
+
+ @property
+ @deprecated(replacement="getRoiAction()", since_version="0.4.0")
+ def roiAction(self):
+ return self.getRoiAction()
+
+ def getRoiAction(self):
+ """QAction toggling curve ROI dock widget
+
+ :rtype: QAction
+ """
+ return self.getCurvesRoiDockWidget().toggleViewAction()
+
+ def getResetZoomAction(self):
+ """Action resetting the zoom
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.resetZoomAction
+
+ def getZoomInAction(self):
+ """Action to zoom in
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.zoomInAction
+
+ def getZoomOutAction(self):
+ """Action to zoom out
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.zoomOutAction
+
+ def getXAxisAutoScaleAction(self):
+ """Action to toggle the X axis autoscale on zoom reset
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.xAxisAutoScaleAction
+
+ def getYAxisAutoScaleAction(self):
+ """Action to toggle the Y axis autoscale on zoom reset
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.yAxisAutoScaleAction
+
+ def getXAxisLogarithmicAction(self):
+ """Action to toggle logarithmic X axis
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.xAxisLogarithmicAction
+
+ def getYAxisLogarithmicAction(self):
+ """Action to toggle logarithmic Y axis
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.yAxisLogarithmicAction
+
+ def getGridAction(self):
+ """Action to toggle the grid visibility in the plot
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.gridAction
+
+ def getCurveStyleAction(self):
+ """Action to change curve line and markers styles
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.curveStyleAction
+
+ def getColormapAction(self):
+ """Action open a colormap dialog to change active image
+ and default colormap.
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.colormapAction
+
+ def getKeepDataAspectRatioButton(self):
+ """Button to toggle aspect ratio preservation
+
+ :rtype: PlotToolButtons.AspectToolButton
+ """
+ return self.keepDataAspectRatioButton
+
+ def getKeepDataAspectRatioAction(self):
+ """Action associated to keepDataAspectRatioButton.
+ Use this to change the visibility of keepDataAspectRatioButton in the
+ toolbar (See :meth:`QToolBar.addWidget` documentation).
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.keepDataAspectRatioButton
+
+ def getYAxisInvertedButton(self):
+ """Button to switch the Y axis orientation
+
+ :rtype: PlotToolButtons.YAxisOriginToolButton
+ """
+ return self.yAxisInvertedButton
+
+ def getYAxisInvertedAction(self):
+ """Action associated to yAxisInvertedButton.
+ Use this to change the visibility yAxisInvertedButton in the toolbar.
+ (See :meth:`QToolBar.addWidget` documentation).
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.yAxisInvertedAction
+
+ def getIntensityHistogramAction(self):
+ """Action toggling the histogram intensity Plot widget
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self._intensityHistoAction
+
+ def getCopyAction(self):
+ """Action to copy plot snapshot to clipboard
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.copyAction
+
+ def getSaveAction(self):
+ """Action to save plot
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.saveAction
+
+ def getPrintAction(self):
+ """Action to print plot
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.printAction
+
+ def getFitAction(self):
+ """Action to fit selected curve
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self.fitAction
+
+ def getMedianFilter1DAction(self):
+ """Action toggling the 1D median filter
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self._medianFilter1DAction
+
+ def getMedianFilter2DAction(self):
+ """Action toggling the 2D median filter
+
+ :rtype: PlotActions.PlotAction
+ """
+ return self._medianFilter2DAction
+
+
+class Plot1D(PlotWindow):
+ """PlotWindow with tools specific for curves.
+
+ This widgets provides the plot API of :class:`.PlotWidget`.
+
+ :param parent: The parent of this widget
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.Plot` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ def __init__(self, parent=None, backend=None):
+ super(Plot1D, self).__init__(parent=parent, backend=backend,
+ resetzoom=True, autoScale=True,
+ logScale=True, grid=True,
+ curveStyle=True, colormap=False,
+ aspectRatio=False, yInverted=False,
+ copy=True, save=True, print_=True,
+ control=True, position=True,
+ roi=True, mask=False, fit=True)
+ if parent is None:
+ self.setWindowTitle('Plot1D')
+ self.setGraphXLabel('X')
+ self.setGraphYLabel('Y')
+
+
+class Plot2D(PlotWindow):
+ """PlotWindow with a toolbar specific for images.
+
+ This widgets provides the plot API of :~:`.PlotWidget`.
+
+ :param parent: The parent of this widget
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.Plot` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ def __init__(self, parent=None, backend=None):
+ # List of information to display at the bottom of the plot
+ posInfo = [
+ ('X', lambda x, y: x),
+ ('Y', lambda x, y: y),
+ ('Data', self._getImageValue)]
+
+ super(Plot2D, self).__init__(parent=parent, backend=backend,
+ resetzoom=True, autoScale=False,
+ logScale=False, grid=False,
+ curveStyle=False, colormap=True,
+ aspectRatio=True, yInverted=True,
+ copy=True, save=True, print_=True,
+ control=False, position=posInfo,
+ roi=False, mask=True)
+ if parent is None:
+ self.setWindowTitle('Plot2D')
+ self.setGraphXLabel('Columns')
+ self.setGraphYLabel('Rows')
+
+ self.profile = ProfileToolBar(plot=self)
+
+ self.addToolBar(self.profile)
+
+ def _getImageValue(self, x, y):
+ """Get value of top most image at position (x, y)
+
+ :param float x: X position in plot coordinates
+ :param float y: Y position in plot coordinates
+ :return: The value at that point or '-'
+ """
+ value = '-'
+ valueZ = - float('inf')
+
+ for image in self.getAllImages():
+ data = image.getData(copy=False)
+ if image.getZValue() >= valueZ: # This image is over the previous one
+ ox, oy = image.getOrigin()
+ sx, sy = image.getScale()
+ row, col = (y - oy) / sy, (x - ox) / sx
+ if row >= 0 and col >= 0:
+ # Test positive before cast otherwise issue with int(-0.5) = 0
+ row, col = int(row), int(col)
+ if (row < data.shape[0] and col < data.shape[1]):
+ value = data[row, col]
+ valueZ = image.getZValue()
+ return value
+
+ def getProfileToolbar(self):
+ """Profile tools attached to this plot
+
+ See :class:`silx.gui.plot.Profile.ProfileToolBar`
+ """
+ return self.profile
+
+ @deprecated(replacement="getProfilePlot", since_version="0.5.0")
+ def getProfileWindow(self):
+ return self.getProfilePlot()
+
+ def getProfilePlot(self):
+ """Return plot window used to display profile curve.
+
+ :return: :class:`Plot1D`
+ """
+ return self.profile.getProfilePlot()
diff --git a/silx/gui/plot/Profile.py b/silx/gui/plot/Profile.py
new file mode 100644
index 0000000..a11b3f0
--- /dev/null
+++ b/silx/gui/plot/Profile.py
@@ -0,0 +1,741 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Utility functions, toolbars and actions to create profile on images
+and stacks of images"""
+
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel", "H. Payno"]
+__license__ = "MIT"
+__date__ = "24/04/2017"
+
+
+import numpy
+
+from silx.image.bilinear import BilinearImage
+
+from .. import icons
+from .. import qt
+from . import items
+from .Colors import cursorColorForColormap
+from .PlotActions import PlotAction
+from .PlotToolButtons import ProfileToolButton
+from .ProfileMainWindow import ProfileMainWindow
+
+from silx.utils.decorators import deprecated
+
+
+def _alignedFullProfile(data, origin, scale, position, roiWidth, axis):
+ """Get a profile along one axis on a stack of images
+
+ :param numpy.ndarray data: 3D volume (stack of 2D images)
+ The first dimension is the image index.
+ :param origin: Origin of image in plot (ox, oy)
+ :param scale: Scale of image in plot (sx, sy)
+ :param float position: Position of profile line in plot coords
+ on the axis orthogonal to the profile direction.
+ :param int roiWidth: Width of the profile in image pixels.
+ :param int axis: 0 for horizontal profile, 1 for vertical.
+ :return: profile image + effective ROI area corners in plot coords
+ """
+ assert axis in (0, 1)
+ assert len(data.shape) == 3
+
+ # Convert from plot to image coords
+ imgPos = int((position - origin[1 - axis]) / scale[1 - axis])
+
+ if axis == 1: # Vertical profile
+ # Transpose image to always do a horizontal profile
+ data = numpy.transpose(data, (0, 2, 1))
+
+ nimages, height, width = data.shape
+
+ roiWidth = min(height, roiWidth) # Clip roi width to image size
+
+ # Get [start, end[ coords of the roi in the data
+ start = int(int(imgPos) + 0.5 - roiWidth / 2.)
+ start = min(max(0, start), height - roiWidth)
+ end = start + roiWidth
+
+ if start < height and end > 0:
+ profile = data[:, max(0, start):min(end, height), :].mean(
+ axis=1, dtype=numpy.float32)
+ else:
+ profile = numpy.zeros((nimages, width), dtype=numpy.float32)
+
+ # Compute effective ROI in plot coords
+ profileBounds = numpy.array(
+ (0, width, width, 0),
+ dtype=numpy.float32) * scale[axis] + origin[axis]
+ roiBounds = numpy.array(
+ (start, start, end, end),
+ dtype=numpy.float32) * scale[1 - axis] + origin[1 - axis]
+
+ if axis == 0: # Horizontal profile
+ area = profileBounds, roiBounds
+ else: # vertical profile
+ area = roiBounds, profileBounds
+
+ return profile, area
+
+
+def _alignedPartialProfile(data, rowRange, colRange, axis):
+ """Mean of a rectangular region (ROI) of a stack of images
+ along a given axis.
+
+ Returned values and all parameters are in image coordinates.
+
+ :param numpy.ndarray data: 3D volume (stack of 2D images)
+ The first dimension is the image index.
+ :param rowRange: [min, max[ of ROI rows (upper bound excluded).
+ :type rowRange: 2-tuple of int (min, max) with min < max
+ :param colRange: [min, max[ of ROI columns (upper bound excluded).
+ :type colRange: 2-tuple of int (min, max) with min < max
+ :param int axis: The axis along which to take the profile of the ROI.
+ 0: Sum rows along columns.
+ 1: Sum columns along rows.
+ :return: Profile image along the ROI as the mean of the intersection
+ of the ROI and the image.
+ """
+ assert axis in (0, 1)
+ assert len(data.shape) == 3
+ assert rowRange[0] < rowRange[1]
+ assert colRange[0] < colRange[1]
+
+ nimages, height, width = data.shape
+
+ # Range aligned with the integration direction
+ profileRange = colRange if axis == 0 else rowRange
+
+ profileLength = abs(profileRange[1] - profileRange[0])
+
+ # Subset of the image to use as intersection of ROI and image
+ rowStart = min(max(0, rowRange[0]), height)
+ rowEnd = min(max(0, rowRange[1]), height)
+ colStart = min(max(0, colRange[0]), width)
+ colEnd = min(max(0, colRange[1]), width)
+
+ imgProfile = numpy.mean(data[:, rowStart:rowEnd, colStart:colEnd],
+ axis=axis + 1, dtype=numpy.float32)
+
+ # Profile including out of bound area
+ profile = numpy.zeros((nimages, profileLength), dtype=numpy.float32)
+
+ # Place imgProfile in full profile
+ offset = - min(0, profileRange[0])
+ profile[:, offset:offset + imgProfile.shape[1]] = imgProfile
+
+ return profile
+
+
+def createProfile(roiInfo, currentData, origin, scale, lineWidth):
+ """Create the profile line for the the given image.
+
+ :param roiInfo: information about the ROI: start point, end point and
+ type ("X", "Y", "D")
+ :param numpy.ndarray currentData: the 2D image or the 3D stack of images
+ on which we compute the profile.
+ :param origin: (ox, oy) the offset from origin
+ :type origin: 2-tuple of float
+ :param scale: (sx, sy) the scale to use
+ :type scale: 2-tuple of float
+ :param int lineWidth: width of the profile line
+ :return: `profile, area, profileName, xLabel`, where:
+ - profile is a 2D array of the profiles of the stack of images.
+ For a single image, the profile is a curve, so this parameter
+ has a shape *(1, len(curve))*
+ - area is a tuple of two 1D arrays with 4 values each. They represent
+ the effective ROI area corners in plot coords.
+ - profileName is a string describing the ROI, meant to be used as
+ title of the profile plot
+ - xLabel is a string describing the meaning of the X axis on the
+ profile plot ("rows", "columns", "distance")
+
+ :rtype: tuple(ndarray, (ndarray, ndarray), str, str)
+ """
+ if currentData is None or roiInfo is None or lineWidth is None:
+ raise ValueError("createProfile called with invalide arguments")
+
+ # force 3D data (stack of images)
+ if len(currentData.shape) == 2:
+ currentData3D = currentData.reshape((1,) + currentData.shape)
+ elif len(currentData.shape) == 3:
+ currentData3D = currentData
+
+ roiWidth = max(1, lineWidth)
+ roiStart, roiEnd, lineProjectionMode = roiInfo
+
+ if lineProjectionMode == 'X': # Horizontal profile on the whole image
+ profile, area = _alignedFullProfile(currentData3D,
+ origin, scale,
+ roiStart[1], roiWidth,
+ axis=0)
+
+ yMin, yMax = min(area[1]), max(area[1]) - 1
+ if roiWidth <= 1:
+ profileName = 'Y = %g' % yMin
+ else:
+ profileName = 'Y = [%g, %g]' % (yMin, yMax)
+ xLabel = 'Columns'
+
+ elif lineProjectionMode == 'Y': # Vertical profile on the whole image
+ profile, area = _alignedFullProfile(currentData3D,
+ origin, scale,
+ roiStart[0], roiWidth,
+ axis=1)
+
+ xMin, xMax = min(area[0]), max(area[0]) - 1
+ if roiWidth <= 1:
+ profileName = 'X = %g' % xMin
+ else:
+ profileName = 'X = [%g, %g]' % (xMin, xMax)
+ xLabel = 'Rows'
+
+ else: # Free line profile
+
+ # Convert start and end points in image coords as (row, col)
+ startPt = ((roiStart[1] - origin[1]) / scale[1],
+ (roiStart[0] - origin[0]) / scale[0])
+ endPt = ((roiEnd[1] - origin[1]) / scale[1],
+ (roiEnd[0] - origin[0]) / scale[0])
+
+ if (int(startPt[0]) == int(endPt[0]) or
+ int(startPt[1]) == int(endPt[1])):
+ # Profile is aligned with one of the axes
+
+ # Convert to int
+ startPt = int(startPt[0]), int(startPt[1])
+ endPt = int(endPt[0]), int(endPt[1])
+
+ # Ensure startPt <= endPt
+ if startPt[0] > endPt[0] or startPt[1] > endPt[1]:
+ startPt, endPt = endPt, startPt
+
+ if startPt[0] == endPt[0]: # Row aligned
+ rowRange = (int(startPt[0] + 0.5 - 0.5 * roiWidth),
+ int(startPt[0] + 0.5 + 0.5 * roiWidth))
+ colRange = startPt[1], endPt[1] + 1
+ profile = _alignedPartialProfile(currentData3D,
+ rowRange, colRange,
+ axis=0)
+
+ else: # Column aligned
+ rowRange = startPt[0], endPt[0] + 1
+ colRange = (int(startPt[1] + 0.5 - 0.5 * roiWidth),
+ int(startPt[1] + 0.5 + 0.5 * roiWidth))
+ profile = _alignedPartialProfile(currentData3D,
+ rowRange, colRange,
+ axis=1)
+
+ # Convert ranges to plot coords to draw ROI area
+ area = (
+ numpy.array(
+ (colRange[0], colRange[1], colRange[1], colRange[0]),
+ dtype=numpy.float32) * scale[0] + origin[0],
+ numpy.array(
+ (rowRange[0], rowRange[0], rowRange[1], rowRange[1]),
+ dtype=numpy.float32) * scale[1] + origin[1])
+
+ else: # General case: use bilinear interpolation
+
+ # Ensure startPt <= endPt
+ if (startPt[1] > endPt[1] or (
+ startPt[1] == endPt[1] and startPt[0] > endPt[0])):
+ startPt, endPt = endPt, startPt
+
+ profile = []
+ for slice_idx in range(currentData3D.shape[0]):
+ bilinear = BilinearImage(currentData3D[slice_idx, :, :])
+
+ profile.append(bilinear.profile_line(
+ (startPt[0] - 0.5, startPt[1] - 0.5),
+ (endPt[0] - 0.5, endPt[1] - 0.5),
+ roiWidth))
+ profile = numpy.array(profile)
+
+ # Extend ROI with half a pixel on each end, and
+ # Convert back to plot coords (x, y)
+ length = numpy.sqrt((endPt[0] - startPt[0]) ** 2 +
+ (endPt[1] - startPt[1]) ** 2)
+ dRow = (endPt[0] - startPt[0]) / length
+ dCol = (endPt[1] - startPt[1]) / length
+
+ # Extend ROI with half a pixel on each end
+ startPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol
+ endPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol
+
+ # Rotate deltas by 90 degrees to apply line width
+ dRow, dCol = dCol, -dRow
+
+ area = (
+ numpy.array((startPt[1] - 0.5 * roiWidth * dCol,
+ startPt[1] + 0.5 * roiWidth * dCol,
+ endPt[1] + 0.5 * roiWidth * dCol,
+ endPt[1] - 0.5 * roiWidth * dCol),
+ dtype=numpy.float32) * scale[0] + origin[0],
+ numpy.array((startPt[0] - 0.5 * roiWidth * dRow,
+ startPt[0] + 0.5 * roiWidth * dRow,
+ endPt[0] + 0.5 * roiWidth * dRow,
+ endPt[0] - 0.5 * roiWidth * dRow),
+ dtype=numpy.float32) * scale[1] + origin[1])
+
+ y0, x0 = startPt
+ y1, x1 = endPt
+ if x1 == x0 or y1 == y0:
+ profileName = 'From (%g, %g) to (%g, %g)' % (x0, y0, x1, y1)
+ else:
+ m = (y1 - y0) / (x1 - x0)
+ b = y0 - m * x0
+ profileName = 'y = %g * x %+g ; width=%d' % (m, b, roiWidth)
+ xLabel = 'Distance'
+
+ return profile, area, profileName, xLabel
+
+
+# ProfileToolBar ##############################################################
+
+class ProfileToolBar(qt.QToolBar):
+ """QToolBar providing profile tools operating on a :class:`PlotWindow`.
+
+ Attributes:
+
+ - plot: Associated :class:`PlotWindow` on which the profile line is drawn.
+ - actionGroup: :class:`QActionGroup` of available actions.
+
+ To run the following sample code, a QApplication must be initialized.
+ First, create a PlotWindow and add a :class:`ProfileToolBar`.
+
+ >>> from silx.gui.plot import PlotWindow
+ >>> from silx.gui.plot.Profile import ProfileToolBar
+
+ >>> plot = PlotWindow() # Create a PlotWindow
+ >>> toolBar = ProfileToolBar(plot=plot) # Create a profile toolbar
+ >>> plot.addToolBar(toolBar) # Add it to plot
+ >>> plot.show() # To display the PlotWindow with the profile toolbar
+
+ :param plot: :class:`PlotWindow` instance on which to operate.
+ :param profileWindow: Plot widget instance where to
+ display the profile curve or None to create one.
+ :param str title: See :class:`QToolBar`.
+ :param parent: See :class:`QToolBar`.
+ """
+ # TODO Make it a QActionGroup instead of a QToolBar
+
+ _POLYGON_LEGEND = '__ProfileToolBar_ROI_Polygon'
+
+ def __init__(self, parent=None, plot=None, profileWindow=None,
+ title='Profile Selection'):
+ super(ProfileToolBar, self).__init__(title, parent)
+ assert plot is not None
+ self.plot = plot
+
+ self._overlayColor = None
+ self._defaultOverlayColor = 'red' # update when active image change
+
+ self._roiInfo = None # Store start and end points and type of ROI
+
+ self._profileWindow = profileWindow
+ """User provided plot widget in which the profile curve is plotted.
+ None if no custom profile plot was provided."""
+
+ self._profileMainWindow = None
+ """Main window providing 2 profile plot widgets for 1D or 2D profiles.
+ The window provides two public methods
+ - :meth:`setProfileDimensions`
+ - :meth:`getPlot`: return handle on the actual plot widget
+ currently being used
+ None if the user specified a custom profile plot window.
+ """
+
+ if self._profileWindow is None:
+ self._profileMainWindow = ProfileMainWindow(self)
+
+ # Actions
+ self.browseAction = qt.QAction(
+ icons.getQIcon('normal'),
+ 'Browsing Mode', None)
+ self.browseAction.setToolTip(
+ 'Enables zooming interaction mode')
+ self.browseAction.setCheckable(True)
+ self.browseAction.triggered[bool].connect(self._browseActionTriggered)
+
+ self.hLineAction = qt.QAction(
+ icons.getQIcon('shape-horizontal'),
+ 'Horizontal Profile Mode', None)
+ self.hLineAction.setToolTip(
+ 'Enables horizontal profile selection mode')
+ self.hLineAction.setCheckable(True)
+ self.hLineAction.toggled[bool].connect(self._hLineActionToggled)
+
+ self.vLineAction = qt.QAction(
+ icons.getQIcon('shape-vertical'),
+ 'Vertical Profile Mode', None)
+ self.vLineAction.setToolTip(
+ 'Enables vertical profile selection mode')
+ self.vLineAction.setCheckable(True)
+ self.vLineAction.toggled[bool].connect(self._vLineActionToggled)
+
+ self.lineAction = qt.QAction(
+ icons.getQIcon('shape-diagonal'),
+ 'Free Line Profile Mode', None)
+ self.lineAction.setToolTip(
+ 'Enables line profile selection mode')
+ self.lineAction.setCheckable(True)
+ self.lineAction.toggled[bool].connect(self._lineActionToggled)
+
+ self.clearAction = qt.QAction(
+ icons.getQIcon('profile-clear'),
+ 'Clear Profile', None)
+ self.clearAction.setToolTip(
+ 'Clear the profile Region of interest')
+ self.clearAction.setCheckable(False)
+ self.clearAction.triggered.connect(self.clearProfile)
+
+ # ActionGroup
+ self.actionGroup = qt.QActionGroup(self)
+ self.actionGroup.addAction(self.browseAction)
+ self.actionGroup.addAction(self.hLineAction)
+ self.actionGroup.addAction(self.vLineAction)
+ self.actionGroup.addAction(self.lineAction)
+
+ self.browseAction.setChecked(True)
+
+ # Add actions to ToolBar
+ self.addAction(self.browseAction)
+ self.addAction(self.hLineAction)
+ self.addAction(self.vLineAction)
+ self.addAction(self.lineAction)
+ self.addAction(self.clearAction)
+
+ # Add width spin box to toolbar
+ self.addWidget(qt.QLabel('W:'))
+ self.lineWidthSpinBox = qt.QSpinBox(self)
+ self.lineWidthSpinBox.setRange(0, 1000)
+ self.lineWidthSpinBox.setValue(1)
+ self.lineWidthSpinBox.valueChanged[int].connect(
+ self._lineWidthSpinBoxValueChangedSlot)
+ self.addWidget(self.lineWidthSpinBox)
+
+ self.plot.sigInteractiveModeChanged.connect(
+ self._interactiveModeChanged)
+
+ # Enable toolbar only if there is an active image
+ self.setEnabled(self.plot.getActiveImage(just_legend=True) is not None)
+ self.plot.sigActiveImageChanged.connect(
+ self._activeImageChanged)
+
+ # listen to the profile window signals to clear profile polygon on close
+ if self.getProfileMainWindow() is not None:
+ self.getProfileMainWindow().sigClose.connect(self.clearProfile)
+
+ @property
+ @deprecated(replacement="getProfilePlot", since_version="0.5.0")
+ def profileWindow(self):
+ return self.getProfilePlot()
+
+ def getProfilePlot(self):
+ """Return plot widget in which the profile curve or the
+ profile image is plotted.
+ """
+ if self.getProfileMainWindow() is not None:
+ return self.getProfileMainWindow().getPlot()
+
+ # in case the user provided a custom plot for profiles
+ return self._profileWindow
+
+ def getProfileMainWindow(self):
+ """Return window containing the profile curve widget.
+ This can return *None* if a custom profile plot window was
+ specified in the constructor.
+ """
+ return self._profileMainWindow
+
+ def _activeImageChanged(self, previous, legend):
+ """Handle active image change: toggle enabled toolbar, update curve"""
+ self.setEnabled(legend is not None)
+ if legend is not None:
+ # Update default profile color
+ activeImage = self.plot.getActiveImage()
+ if isinstance(activeImage, items.ColormapMixIn):
+ self._defaultOverlayColor = cursorColorForColormap(
+ activeImage.getColormap()['name'])
+ else:
+ self._defaultOverlayColor = 'black'
+
+ self.updateProfile()
+
+ def _lineWidthSpinBoxValueChangedSlot(self, value):
+ """Listen to ROI width widget to refresh ROI and profile"""
+ self.updateProfile()
+
+ def _interactiveModeChanged(self, source):
+ """Handle plot interactive mode changed:
+
+ If changed from elsewhere, disable drawing tool
+ """
+ if source is not self:
+ self.browseAction.setChecked(True)
+
+ def _hLineActionToggled(self, checked):
+ """Handle horizontal line profile action toggle"""
+ if checked:
+ self.plot.setInteractiveMode('draw', shape='hline',
+ color=None, source=self)
+ self.plot.sigPlotSignal.connect(self._plotWindowSlot)
+ else:
+ self.plot.sigPlotSignal.disconnect(self._plotWindowSlot)
+
+ def _vLineActionToggled(self, checked):
+ """Handle vertical line profile action toggle"""
+ if checked:
+ self.plot.setInteractiveMode('draw', shape='vline',
+ color=None, source=self)
+ self.plot.sigPlotSignal.connect(self._plotWindowSlot)
+ else:
+ self.plot.sigPlotSignal.disconnect(self._plotWindowSlot)
+
+ def _lineActionToggled(self, checked):
+ """Handle line profile action toggle"""
+ if checked:
+ self.plot.setInteractiveMode('draw', shape='line',
+ color=None, source=self)
+ self.plot.sigPlotSignal.connect(self._plotWindowSlot)
+ else:
+ self.plot.sigPlotSignal.disconnect(self._plotWindowSlot)
+
+ def _browseActionTriggered(self, checked):
+ """Handle browse action mode triggered by user."""
+ if checked:
+ self.clearProfile()
+ self.plot.setInteractiveMode('zoom', source=self)
+ if self.getProfileMainWindow() is not None:
+ self.getProfileMainWindow().hide()
+
+ def _plotWindowSlot(self, event):
+ """Listen to Plot to handle drawing events to refresh ROI and profile.
+ """
+ if event['event'] not in ('drawingProgress', 'drawingFinished'):
+ return
+
+ checkedAction = self.actionGroup.checkedAction()
+ if checkedAction == self.hLineAction:
+ lineProjectionMode = 'X'
+ elif checkedAction == self.vLineAction:
+ lineProjectionMode = 'Y'
+ elif checkedAction == self.lineAction:
+ lineProjectionMode = 'D'
+ else:
+ return
+
+ roiStart, roiEnd = event['points'][0], event['points'][1]
+
+ self._roiInfo = roiStart, roiEnd, lineProjectionMode
+ self.updateProfile()
+
+ @property
+ def overlayColor(self):
+ """The color to use for the ROI.
+
+ If set to None (the default), the overlay color is adapted to the
+ active image colormap and changes if the active image colormap changes.
+ """
+ return self._overlayColor or self._defaultOverlayColor
+
+ @overlayColor.setter
+ def overlayColor(self, color):
+ self._overlayColor = color
+ self.updateProfile()
+
+ def clearProfile(self):
+ """Remove profile curve and profile area."""
+ self._roiInfo = None
+ self.updateProfile()
+
+ def updateProfile(self):
+ """Update the displayed profile and profile ROI.
+
+ This uses the current active image of the plot and the current ROI.
+ """
+ image = self.plot.getActiveImage()
+ if image is None:
+ return
+
+ # Clean previous profile area, and previous curve
+ self.plot.remove(self._POLYGON_LEGEND, kind='item')
+ self.getProfilePlot().clear()
+ self.getProfilePlot().setGraphTitle('')
+ self.getProfilePlot().setGraphXLabel('X')
+ self.getProfilePlot().setGraphYLabel('Y')
+
+ self._createProfile(currentData=image.getData(copy=False),
+ origin=image.getOrigin(),
+ scale=image.getScale(),
+ colormap=None, # Not used for 2D data
+ z=image.getZValue())
+
+ def _createProfile(self, currentData, origin, scale, colormap, z):
+ """Create the profile line for the the given image.
+
+ :param numpy.ndarray currentData: the image or the stack of images
+ on which we compute the profile
+ :param origin: (ox, oy) the offset from origin
+ :type origin: 2-tuple of float
+ :param scale: (sx, sy) the scale to use
+ :type scale: 2-tuple of float
+ :param dict colormap: The colormap to use
+ :param int z: The z layer of the image
+ """
+ if self._roiInfo is None:
+ return
+
+ profile, area, profileName, xLabel = createProfile(
+ roiInfo=self._roiInfo,
+ currentData=currentData,
+ origin=origin,
+ scale=scale,
+ lineWidth=self.lineWidthSpinBox.value())
+
+ self.getProfilePlot().setGraphTitle(profileName)
+
+ dataIs3D = len(currentData.shape) > 2
+ if dataIs3D:
+ self.getProfilePlot().addImage(profile,
+ legend=profileName,
+ xlabel=xLabel,
+ ylabel="Frame index (depth)",
+ colormap=colormap)
+ else:
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ self.getProfilePlot().addCurve(coords,
+ profile[0],
+ legend=profileName,
+ xlabel=xLabel,
+ color=self.overlayColor)
+
+ self.plot.addItem(area[0], area[1],
+ legend=self._POLYGON_LEGEND,
+ color=self.overlayColor,
+ shape='polygon', fill=True,
+ replace=False, z=z + 1)
+
+ self._showProfileMainWindow()
+
+ def _showProfileMainWindow(self):
+ """If profile window was created by this toolbar,
+ try to avoid overlapping with the toolbar's parent window.
+ """
+ profileMainWindow = self.getProfileMainWindow()
+ if profileMainWindow is not None:
+ winGeom = self.window().frameGeometry()
+ qapp = qt.QApplication.instance()
+ screenGeom = qapp.desktop().availableGeometry(self)
+
+ spaceOnLeftSide = winGeom.left()
+ spaceOnRightSide = screenGeom.width() - winGeom.right()
+
+ profileWindowWidth = profileMainWindow.frameGeometry().width()
+ if (profileWindowWidth < spaceOnRightSide or
+ spaceOnRightSide > spaceOnLeftSide):
+ # Place profile on the right
+ profileMainWindow.move(winGeom.right(), winGeom.top())
+ else:
+ # Not enough place on the right, place profile on the left
+ profileMainWindow.move(
+ max(0, winGeom.left() - profileWindowWidth), winGeom.top())
+
+ profileMainWindow.show()
+ else:
+ self.getProfilePlot().show()
+
+ def hideProfileWindow(self):
+ """Hide profile window.
+ """
+ # this method is currently only used by StackView when the perspective
+ # is changed
+ if self.getProfileMainWindow() is not None:
+ self.getProfileMainWindow().hide()
+
+
+class Profile3DToolBar(ProfileToolBar):
+ def __init__(self, parent=None, plot=None, title='Profile Selection'):
+ """QToolBar providing profile tools for an image or a stack of images.
+
+ :param parent: the parent QWidget
+ :param plot: :class:`PlotWindow` instance on which to operate.
+ :param str title: See :class:`QToolBar`.
+ :param parent: See :class:`QToolBar`.
+ """
+ # TODO: add param profileWindow (specify the plot used for profiles)
+ super(Profile3DToolBar, self).__init__(parent=parent, plot=plot,
+ title=title)
+
+ self.profile3dAction = ProfileToolButton(
+ parent=self, plot=self.plot)
+ self.profile3dAction.computeProfileIn2D()
+ self.profile3dAction.setVisible(True)
+ self.addWidget(self.profile3dAction)
+ self.profile3dAction.sigDimensionChanged.connect(self._setProfileType)
+
+ # create the 3D toolbar
+ self._profileType = None
+ self._setProfileType(2)
+
+ def _setProfileType(self, dimensions):
+ """Set the profile type: "1D" for a curve (profile on a single image)
+ or "2D" for an image (profile on a stack of images).
+
+ :param int dimensions: 1 for a "1D" profile or 2 for a "2D" profile
+ """
+ # fixme this assumes that we created _profileMainWindow
+ self._profileType = "1D" if dimensions == 1 else "2D"
+ self.getProfileMainWindow().setProfileType(self._profileType)
+ self.updateProfile()
+
+ def updateProfile(self):
+ """Method overloaded from :class:`ProfileToolBar`,
+ to pass the stack of images instead of just the active image.
+
+ In 1D profile mode, use the regular parent method.
+ """
+ if self._profileType == "1D":
+ super(Profile3DToolBar, self).updateProfile()
+ elif self._profileType == "2D":
+ stackData = self.plot.getCurrentView(copy=False,
+ returnNumpyArray=True)
+ if stackData is None:
+ return
+ self.plot.remove(self._POLYGON_LEGEND, kind='item')
+ self.getProfilePlot().clear()
+ self.getProfilePlot().setGraphTitle('')
+ self.getProfilePlot().setGraphXLabel('X')
+ self.getProfilePlot().setGraphYLabel('Y')
+
+ self._createProfile(currentData=stackData[0],
+ origin=stackData[1]['origin'],
+ scale=stackData[1]['scale'],
+ colormap=stackData[1]['colormap'],
+ z=stackData[1]['z'])
+ else:
+ raise ValueError(
+ "Profile type must be 1D or 2D, not %s" % self._profileType)
diff --git a/silx/gui/plot/ProfileMainWindow.py b/silx/gui/plot/ProfileMainWindow.py
new file mode 100644
index 0000000..835de2c
--- /dev/null
+++ b/silx/gui/plot/ProfileMainWindow.py
@@ -0,0 +1,99 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module contains a QMainWindow class used to display profile plots.
+"""
+from silx.gui import qt
+
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "21/02/2017"
+
+
+class ProfileMainWindow(qt.QMainWindow):
+ """QMainWindow providing 2 plot widgets specialized in
+ 1D and 2D plotting, with different toolbars.
+ Only one of the plots is visible at any given time.
+ """
+ sigProfileDimensionsChanged = qt.Signal(int)
+ """This signal is emitted when :meth:`setProfileDimensions` is called.
+ It carries the number of dimensions for the profile data (1 or 2).
+ It can be used to be notified that the profile plot widget has changed.
+ """
+
+ sigClose = qt.Signal()
+ """Emitted by :meth:`closeEvent` (e.g. when the window is closed
+ through the window manager's close icon)."""
+
+ def __init__(self, parent=None):
+ qt.QMainWindow.__init__(self, parent=parent)
+
+ self.setWindowTitle('Profile window')
+ # plots are created on demand, in self.setProfileDimensions()
+ self._plot1D = None
+ self._plot2D = None
+ # by default, profile is assumed to be a 1D curve
+ self._profileType = None
+ self.setProfileType("1D")
+
+ def setProfileType(self, profileType):
+ """Set which profile plot widget (1D or 2D) is to be used
+
+ :param str profileType: Type of profile data,
+ "1D" for a curve or "2D" for an image
+ """
+ # import here to avoid circular import
+ from .PlotWindow import Plot1D, Plot2D # noqa
+ self._profileType = profileType
+
+ if self._profileType == "1D":
+ if self._plot2D is not None:
+ self._plot2D.setParent(None) # necessary to avoid widget destruction
+ if self._plot1D is None:
+ self._plot1D = Plot1D()
+ self.setCentralWidget(self._plot1D)
+ elif self._profileType == "2D":
+ if self._plot1D is not None:
+ self._plot1D.setParent(None) # necessary to avoid widget destruction
+ if self._plot2D is None:
+ self._plot2D = Plot2D()
+ self.setCentralWidget(self._plot2D)
+ else:
+ raise ValueError("Profile type must be '1D' or '2D'")
+
+ self.sigProfileDimensionsChanged.emit(profileType)
+
+ def getPlot(self):
+ """Return the profile plot widget which is currently in use.
+ This can be the 2D profile plot or the 1D profile plot.
+ """
+ if self._profileType == "2D":
+ return self._plot2D
+ else:
+ return self._plot1D
+
+ def closeEvent(self, qCloseEvent):
+ self.sigClose.emit()
+ qCloseEvent.accept()
diff --git a/silx/gui/plot/ScatterMaskToolsWidget.py b/silx/gui/plot/ScatterMaskToolsWidget.py
new file mode 100644
index 0000000..793719d
--- /dev/null
+++ b/silx/gui/plot/ScatterMaskToolsWidget.py
@@ -0,0 +1,529 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget providing a set of tools to draw masks on a PlotWidget.
+
+This widget is meant to work with a modified :class:`silx.gui.plot.PlotWidget`
+
+- :class:`ScatterMask`: Handle scatter mask update and history
+- :class:`ScatterMaskToolsWidget`: GUI for :class:`ScatterMask`
+- :class:`ScatterMaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow`
+"""
+
+from __future__ import division
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "07/04/2017"
+
+
+import math
+import logging
+import os
+import numpy
+import sys
+
+from .. import qt
+from ...image import shapes
+
+from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget
+from .Colors import cursorColorForColormap, rgba
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ScatterMask(BaseMask):
+ """A 1D mask for scatter data.
+ """
+ def __init__(self, scatter=None):
+ """
+
+ :param scatter: :class:`silx.gui.plot.items.Scatter` instance
+ """
+ BaseMask.__init__(self, scatter)
+
+ def _getXY(self):
+ x = self._dataItem.getXData(copy=False)
+ y = self._dataItem.getYData(copy=False)
+ return x, y
+
+ def getDataValues(self):
+ """Return scatter data values as a 1D array.
+
+ :rtype: 1D numpy.ndarray
+ """
+ return self._dataItem.getValueData(copy=False)
+
+ def save(self, filename, kind):
+ if kind == 'npy':
+ try:
+ numpy.save(filename, self.getMask(copy=False))
+ except IOError:
+ raise RuntimeError("Mask file can't be written")
+ elif kind in ["csv", "txt"]:
+ try:
+ numpy.savetxt(filename, self.getMask(copy=False))
+ except IOError:
+ raise RuntimeError("Mask file can't be written")
+
+ def updatePoints(self, level, indices, mask=True):
+ """Mask/Unmask points with given indices.
+
+ :param int level: Mask level to update.
+ :param indices: Sequence or 1D array of indices of points to be
+ updated
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ if mask:
+ self._mask[indices] = level
+ else:
+ # unmask only where mask level is the specified value
+ indices_stencil = numpy.zeros_like(self._mask, dtype=numpy.bool)
+ indices_stencil[indices] = True
+ self._mask[numpy.logical_and(self._mask == level, indices_stencil)] = 0
+ self._notify()
+
+ # update shapes
+ def updatePolygon(self, level, vertices, mask=True):
+ """Mask/Unmask a polygon of the given mask level.
+
+ :param int level: Mask level to update.
+ :param vertices: Nx2 array of polygon corners as (y, x) or (row, col)
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ polygon = shapes.Polygon(vertices)
+ x, y = self._getXY()
+
+ # TODO: this could be optimized if necessary
+ indices_in_polygon = [idx for idx in range(len(x)) if
+ polygon.is_inside(y[idx], x[idx])]
+
+ self.updatePoints(level, indices_in_polygon, mask)
+
+ def updateRectangle(self, level, y, x, height, width, mask=True):
+ """Mask/Unmask data inside a rectangle
+
+ :param int level: Mask level to update.
+ :param float y: Y coordinate of bottom left corner of the rectangle
+ :param float x: X coordinate of bottom left corner of the rectangle
+ :param float height:
+ :param float width:
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ vertices = [(y, x),
+ (y + height, x),
+ (y + height, x + width),
+ (y, x + width)]
+ self.updatePolygon(level, vertices, mask)
+
+ def updateDisk(self, level, cy, cx, radius, mask=True):
+ """Mask/Unmask a disk of the given mask level.
+
+ :param int level: Mask level to update.
+ :param float cy: Disk center (y).
+ :param float cx: Disk center (x).
+ :param float radius: Radius of the disk in mask array unit
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ x, y = self._getXY()
+ stencil = (y - cy)**2 + (x - cx)**2 < radius**2
+ self.updateStencil(level, stencil, mask)
+
+ def updateLine(self, level, y0, x0, y1, x1, width, mask=True):
+ """Mask/Unmask points inside a rectangle defined by a line (two
+ end points) and a width.
+
+ :param int level: Mask level to update.
+ :param float y0: Row of the starting point.
+ :param float x0: Column of the starting point.
+ :param float row1: Row of the end point.
+ :param float col1: Column of the end point.
+ :param float width: Width of the line.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ # theta is the angle between the horizontal and the line
+ theta = math.atan((y1 - y0) / (x1 - x0)) if x1 - x0 else 0
+ w_over_2_sin_theta = width / 2. * math.sin(theta)
+ w_over_2_cos_theta = width / 2. * math.cos(theta)
+
+ vertices = [(y0 - w_over_2_cos_theta, x0 + w_over_2_sin_theta),
+ (y0 + w_over_2_cos_theta, x0 - w_over_2_sin_theta),
+ (y1 + w_over_2_cos_theta, x1 - w_over_2_sin_theta),
+ (y1 - w_over_2_cos_theta, x1 + w_over_2_sin_theta)]
+
+ self.updatePolygon(level, vertices, mask)
+
+
+class ScatterMaskToolsWidget(BaseMaskToolsWidget):
+ """Widget with tools for masking data points on a scatter in a
+ :class:`PlotWidget`."""
+
+ def __init__(self, parent=None, plot=None):
+ self._z = 2 # Mask layer in plot
+ self._data_scatter = None
+ """plot Scatter item for data"""
+ self._mask_scatter = None
+ """plot Scatter item for representing the mask"""
+
+ self._mask = ScatterMask()
+
+ super(ScatterMaskToolsWidget, self).__init__(parent, plot)
+
+ self._initWidgets()
+
+ def setSelectionMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask: The array to use for the mask.
+ :type mask: numpy.ndarray of uint8, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ :return: None if failed, shape of mask as 1-tuple if successful.
+ The mask can be cropped or padded to fit active scatter,
+ the returned shape is that of the scatter data.
+ """
+ mask = numpy.array(mask, copy=False, dtype=numpy.uint8)
+
+ if self._data_scatter.getXData(copy=False).shape == (0,) \
+ or mask.shape == self._data_scatter.getXData(copy=False).shape:
+ self._mask.setMask(mask, copy=copy)
+ self._mask.commit()
+ return mask.shape
+ else:
+ raise ValueError("Mask does not have the same shape as the data")
+
+ # Handle mask refresh on the plot
+
+ def _updatePlotMask(self):
+ """Update mask image in plot"""
+ mask = self.getSelectionMask(copy=False)
+ if len(mask):
+ self.plot.addScatter(self._data_scatter.getXData(),
+ self._data_scatter.getYData(),
+ mask,
+ legend=self._maskName,
+ colormap=self._colormap,
+ z=self._z)
+ self._mask_scatter = self.plot._getItem(kind="scatter",
+ legend=self._maskName)
+ self._mask_scatter.setSymbolSize(
+ self._data_scatter.getSymbolSize() * 4.0
+ )
+ elif self.plot._getItem(kind="scatter",
+ legend=self._maskName) is not None:
+ self.plot.remove(self._maskName, kind='scatter')
+
+ # track widget visibility and plot active image changes
+
+ def showEvent(self, event):
+ try:
+ self.plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChangedAfterCare)
+ except (RuntimeError, TypeError):
+ pass
+ self._activeScatterChanged(None, None) # Init mask + enable/disable widget
+ self.plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+
+ def hideEvent(self, event):
+ self.plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged)
+ if not self.browseAction.isChecked():
+ self.browseAction.trigger() # Disable drawing tool
+
+ if len(self.getSelectionMask(copy=False)):
+ self.plot.sigActiveScatterChanged.connect(
+ self._activeScatterChangedAfterCare)
+
+ def _activeScatterChangedAfterCare(self, previous, next):
+ """Check synchro of active scatter and mask when mask widget is hidden.
+
+ If active image has no more the same size as the mask, the mask is
+ removed, otherwise it is adjusted to z.
+ """
+ # check that content changed was the active scatter
+ activeScatter = self.plot._getActiveItem(kind="scatter")
+
+ if activeScatter is None or activeScatter.getLegend() == self._maskName:
+ # No active scatter or active scatter is the mask...
+ self.plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChangedAfterCare)
+ else:
+ colormap = activeScatter.getColormap()
+ self._defaultOverlayColor = rgba(cursorColorForColormap(colormap['name']))
+ self._setMaskColors(self.levelSpinBox.value(),
+ self.transparencySlider.value() /
+ self.transparencySlider.maximum())
+
+ self._z = activeScatter.getZValue() + 1
+ self._data_scatter = activeScatter
+ if self._data_scatter.getXData(copy=False).shape != self.getSelectionMask(copy=False).shape:
+ # scatter has not the same size, remove mask and stop listening
+ if self.plot._getItem(kind="scatter", legend=self._maskName):
+ self.plot.remove(self._maskName, kind='scatter')
+
+ self.plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChangedAfterCare)
+ else:
+ # Refresh in case z changed
+ self._mask.setDataItem(self._data_scatter)
+ self._updatePlotMask()
+
+ def _activeScatterChanged(self, previous, next):
+ """Update widget and mask according to active scatter changes"""
+ activeScatter = self.plot._getActiveItem(kind="scatter")
+
+ if activeScatter is None or activeScatter.getLegend() == self._maskName:
+ # No active scatter or active scatter is the mask...
+ self.setEnabled(False)
+
+ self._data_scatter = None
+ self._mask.reset()
+ self._mask.commit()
+
+ else: # There is an active scatter
+ self.setEnabled(True)
+
+ colormap = activeScatter.getColormap()
+ self._defaultOverlayColor = rgba(cursorColorForColormap(colormap['name']))
+ self._setMaskColors(self.levelSpinBox.value(),
+ self.transparencySlider.value() /
+ self.transparencySlider.maximum())
+
+ self._z = activeScatter.getZValue() + 1
+ self._data_scatter = activeScatter
+ self._mask.setDataItem(self._data_scatter)
+ if self._data_scatter.getXData(copy=False).shape != self.getSelectionMask(copy=False).shape:
+ self._mask.reset(self._data_scatter.getXData(copy=False).shape)
+ self._mask.commit()
+ else:
+ # Refresh in case z changed
+ self._updatePlotMask()
+
+ self._updateInteractiveMode()
+
+ # Handle whole mask operations
+
+ def load(self, filename):
+ """Load a mask from an image file.
+
+ :param str filename: File name from which to load the mask
+ :raise Exception: An exception in case of failure
+ :raise RuntimeWarning: In case the mask was applied but with some
+ import changes to notice
+ """
+ _, extension = os.path.splitext(filename)
+ extension = extension.lower()[1:]
+ if extension == "npy":
+ try:
+ mask = numpy.load(filename)
+ except IOError:
+ _logger.error("Can't load filename '%s'", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError('File "%s" is not a numpy file.',
+ filename)
+ elif extension in ["txt", "csv"]:
+ try:
+ mask = numpy.loadtxt(filename)
+ except IOError:
+ _logger.error("Can't load filename '%s'", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError('File "%s" is not a numpy txt file.',
+ filename)
+ else:
+ msg = "Extension '%s' is not supported."
+ raise RuntimeError(msg % extension)
+
+ self.setSelectionMask(mask, copy=False)
+
+ def _loadMask(self):
+ """Open load mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Load Mask")
+ dialog.setModal(1)
+ filters = [
+ 'NumPy binary file (*.npy)',
+ 'CSV text file (*.csv)',
+ ]
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.ExistingFile)
+ dialog.setDirectory(self.maskFileDir)
+ if not dialog.exec_():
+ dialog.close()
+ return
+
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ self.maskFileDir = os.path.dirname(filename)
+ try:
+ self.load(filename)
+ # except RuntimeWarning as e:
+ # message = e.args[0]
+ # msg = qt.QMessageBox(self)
+ # msg.setIcon(qt.QMessageBox.Warning)
+ # msg.setText("Mask loaded but an operation was applied.\n" + message)
+ # msg.exec_()
+ except Exception as e:
+ message = e.args[0]
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Cannot load mask from file. " + message)
+ msg.exec_()
+
+ def _saveMask(self):
+ """Open Save mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Save Mask")
+ dialog.setModal(1)
+ filters = [
+ 'NumPy binary file (*.npy)',
+ 'CSV text file (*.csv)',
+ ]
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+ dialog.setDirectory(self.maskFileDir)
+ if not dialog.exec_():
+ dialog.close()
+ return
+
+ # convert filter name to extension name with the .
+ extension = dialog.selectedNameFilter().split()[-1][2:-1]
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ if not filename.lower().endswith(extension):
+ filename += extension
+
+ if os.path.exists(filename):
+ try:
+ os.remove(filename)
+ except IOError:
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Cannot save.\n"
+ "Input Output Error: %s" % (sys.exc_info()[1]))
+ msg.exec_()
+ return
+
+ self.maskFileDir = os.path.dirname(filename)
+ try:
+ self.save(filename, extension[1:])
+ except Exception as e:
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Cannot save file %s\n%s" % (filename, e.args[0]))
+ msg.exec_()
+
+ def resetSelectionMask(self):
+ """Reset the mask"""
+ self._mask.reset(
+ shape=self._data_scatter.getXData(copy=False).shape)
+ self._mask.commit()
+
+ def _plotDrawEvent(self, event):
+ """Handle draw events from the plot"""
+ if (self._drawingMode is None or
+ event['event'] not in ('drawingProgress', 'drawingFinished')):
+ return
+
+ if not len(self._data_scatter.getXData(copy=False)):
+ return
+
+ level = self.levelSpinBox.value()
+
+ if (self._drawingMode == 'rectangle' and
+ event['event'] == 'drawingFinished'):
+ doMask = self._isMasking()
+
+ self._mask.updateRectangle(
+ level,
+ y=event['y'],
+ x=event['x'],
+ height=abs(event['height']),
+ width=abs(event['width']),
+ mask=doMask)
+ self._mask.commit()
+
+ elif (self._drawingMode == 'polygon' and
+ event['event'] == 'drawingFinished'):
+ doMask = self._isMasking()
+ vertices = event['points']
+ vertices = vertices.astype(numpy.int)[:, (1, 0)] # (y, x)
+ self._mask.updatePolygon(level, vertices, doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'pencil':
+ doMask = self._isMasking()
+ # convert from plot to array coords
+ x, y = event['points'][-1]
+ brushSize = self.pencilSpinBox.value()
+
+ if self._lastPencilPos != (y, x):
+ if self._lastPencilPos is not None:
+ # Draw the line
+ self._mask.updateLine(
+ level,
+ self._lastPencilPos[0], self._lastPencilPos[1],
+ y, x,
+ brushSize,
+ doMask)
+
+ # Draw the very first, or last point
+ self._mask.updateDisk(level, y, x, brushSize / 2., doMask)
+
+ if event['event'] == 'drawingFinished':
+ self._mask.commit()
+ self._lastPencilPos = None
+ else:
+ self._lastPencilPos = y, x
+
+ def _loadRangeFromColormapTriggered(self):
+ """Set range from active scatter colormap range"""
+ if self._data_scatter is not None:
+ # Update thresholds according to colormap
+ colormap = self._data_scatter.getColormap()
+ if colormap['autoscale']:
+ min_ = numpy.nanmin(self._data_scatter.getValueData(copy=False))
+ max_ = numpy.nanmax(self._data_scatter.getValueData(copy=False))
+ else:
+ min_, max_ = colormap['vmin'], colormap['vmax']
+ self.minLineEdit.setText(str(min_))
+ self.maxLineEdit.setText(str(max_))
+
+
+class ScatterMaskToolsDockWidget(BaseMaskToolsDockWidget):
+ """:class:`ScatterMaskToolsWidget` embedded in a QDockWidget.
+
+ For integration in a :class:`PlotWindow`.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: The PlotWidget this widget is operating on
+ :paran str name: The title of this widget
+ """
+ def __init__(self, parent=None, plot=None, name='Mask'):
+ super(ScatterMaskToolsDockWidget, self).__init__(parent, name)
+ self.setWidget(ScatterMaskToolsWidget(plot=plot))
+ self.widget().sigMaskChanged.connect(self._emitSigMaskChanged)
diff --git a/silx/gui/plot/StackView.py b/silx/gui/plot/StackView.py
new file mode 100644
index 0000000..9bb0cf0
--- /dev/null
+++ b/silx/gui/plot/StackView.py
@@ -0,0 +1,1033 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""QWidget displaying a 3D volume as a stack of 2D images.
+
+The :class:`StackView` class implements this widget.
+
+Basic usage of :class:`StackView` is through the following methods:
+
+- :meth:`StackView.getColormap`, :meth:`StackView.setColormap` to update the
+ default colormap to use and update the currently displayed image.
+- :meth:`StackView.setStack` to update the displayed image.
+
+The :class:`StackView` uses :class:`PlotWindow` and also
+exposes a subset of the :class:`silx.gui.plot.Plot` API for further control
+(plot title, axes labels, ...).
+
+The :class:`StackViewMainWindow` class implements a widget that adds a status
+bar displaying the 3D index and the value under the mouse cursor.
+
+Example::
+
+ import numpy
+ import sys
+ from silx.gui import qt
+ from silx.gui.plot.StackView import StackViewMainWindow
+
+
+ app = qt.QApplication(sys.argv[1:])
+
+ # synthetic data, stack of 100 images of size 200x300
+ mystack = numpy.fromfunction(
+ lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
+ (100, 200, 300)
+ )
+
+
+ sv = StackViewMainWindow()
+ sv.setColormap("jet", autoscale=True)
+ sv.setStack(mystack)
+ sv.setLabels(["1st dim (0-99)", "2nd dim (0-199)",
+ "3rd dim (0-299)"])
+ sv.show()
+
+ app.exec_()
+
+"""
+
+__authors__ = ["P. Knobel", "H. Payno"]
+__license__ = "MIT"
+__date__ = "20/01/2017"
+
+import numpy
+
+try:
+ import h5py
+except ImportError:
+ h5py = None
+
+from silx.gui import qt
+from .. import icons
+from . import items, PlotWindow, PlotActions
+from .Colors import cursorColorForColormap
+from .PlotTools import LimitsToolBar
+from .Profile import Profile3DToolBar
+from ..widgets.FrameBrowser import HorizontalSliderWithBrowser
+
+from silx.utils.array_like import DatasetView, ListOfImages
+from silx.math import calibration
+
+
+class StackView(qt.QMainWindow):
+ """Stack view widget, to display and browse through stack of
+ images.
+
+ The profile tool can be switched to "3D" mode, to compute the profile
+ on each image of the stack (not only the active image currently displayed)
+ and display the result as a slice.
+
+ :param QWidget parent: the Qt parent, or None
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.Plot` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ :param bool resetzoom: Toggle visibility of reset zoom action.
+ :param bool autoScale: Toggle visibility of axes autoscale actions.
+ :param bool logScale: Toggle visibility of axes log scale actions.
+ :param bool grid: Toggle visibility of grid mode action.
+ :param bool colormap: Toggle visibility of colormap action.
+ :param bool aspectRatio: Toggle visibility of aspect ratio button.
+ :param bool yInverted: Toggle visibility of Y axis direction button.
+ :param bool copy: Toggle visibility of copy action.
+ :param bool save: Toggle visibility of save action.
+ :param bool print_: Toggle visibility of print action.
+ :param bool control: True to display an Options button with a sub-menu
+ to show legends, toggle crosshair and pan with arrows.
+ (Default: False)
+ :param position: True to display widget with (x, y) mouse position
+ (Default: False).
+ It also supports a list of (name, funct(x, y)->value)
+ to customize the displayed values.
+ See :class:`silx.gui.plot.PlotTools.PositionInfo`.
+ :param bool mask: Toggle visibilty of mask action.
+ """
+ # Qt signals
+ valueChanged = qt.Signal(object, object, object)
+ """Signals that the data value under the cursor has changed.
+
+ It provides: row, column, data value.
+ """
+
+ sigPlaneSelectionChanged = qt.Signal(int)
+ """Signal emitted when there is a change is perspective/displayed axes.
+
+ It provides the perspective as an integer, with the following meaning:
+
+ - 0: axis Y is the 2nd dimension, axis X is the 3rd dimension
+ - 1: axis Y is the 1st dimension, axis X is the 3rd dimension
+ - 2: axis Y is the 1st dimension, axis X is the 2nd dimension
+ """
+
+ sigStackChanged = qt.Signal(int)
+ """Signal emitted when the stack is changed.
+ This happens when a new volume is loaded, or when the current volume
+ is transposed (change in perspective).
+
+ The signal provides the size (number of pixels) of the stack.
+ This will be 0 if the stack is cleared, else it will be a positive
+ integer.
+ """
+
+ def __init__(self, parent=None, resetzoom=True, backend=None,
+ autoScale=False, logScale=False, grid=False,
+ colormap=True, aspectRatio=True, yinverted=True,
+ copy=True, save=True, print_=True, control=False,
+ position=None, mask=True):
+ qt.QMainWindow.__init__(self, parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+ else:
+ self.setWindowTitle('StackView')
+
+ self._stack = None
+ """Loaded stack, as a 3D array, a 3D dataset or a list of 2D arrays."""
+ self.__transposed_view = None
+ """View on :attr:`_stack` with the axes sorted, to have
+ the orthogonal dimension first"""
+ self._perspective = 0
+ """Orthogonal dimension (depth) in :attr:`_stack`"""
+
+ self.__imageLegend = '__StackView__image' + str(id(self))
+ self.__autoscaleCmap = False
+ """Flag to disable/enable colormap auto-scaling
+ based on the min/max values of the entire 3D volume"""
+ self.__dimensionsLabels = ["Dimension 0", "Dimension 1",
+ "Dimension 2"]
+ """These labels are displayed on the X and Y axes.
+ :meth:`setLabels` updates this attribute."""
+
+ self._first_stack_dimension = 0
+ """Used for dimension labels and combobox"""
+
+ central_widget = qt.QWidget(self)
+
+ self._plot = PlotWindow(parent=central_widget, backend=backend,
+ resetzoom=resetzoom, autoScale=autoScale,
+ logScale=logScale, grid=grid,
+ curveStyle=False, colormap=colormap,
+ aspectRatio=aspectRatio, yInverted=yinverted,
+ copy=copy, save=save, print_=print_,
+ control=control, position=position,
+ roi=False, mask=mask)
+ self.sigInteractiveModeChanged = self._plot.sigInteractiveModeChanged
+ self.sigActiveImageChanged = self._plot.sigActiveImageChanged
+ self.sigPlotSignal = self._plot.sigPlotSignal
+
+ self._plot.profile = Profile3DToolBar(parent=self._plot,
+ plot=self)
+ self._plot.addToolBar(self._plot.profile)
+ self._plot.setGraphXLabel('Columns')
+ self._plot.setGraphYLabel('Rows')
+ self._plot.sigPlotSignal.connect(self._plotCallback)
+
+ self.__planeSelection = PlanesWidget(self._plot)
+ self.__planeSelection.sigPlaneSelectionChanged.connect(self.__setPerspective)
+
+ self._browser_label = qt.QLabel("Image index (Dim0):")
+
+ self._browser = HorizontalSliderWithBrowser(central_widget)
+ self._browser.valueChanged[int].connect(self.__updateFrameNumber)
+ self._browser.setEnabled(False)
+
+ layout = qt.QGridLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._plot, 0, 0, 1, 3)
+ layout.addWidget(self.__planeSelection, 1, 0)
+ layout.addWidget(self._browser_label, 1, 1)
+ layout.addWidget(self._browser, 1, 2)
+
+ central_widget.setLayout(layout)
+ self.setCentralWidget(central_widget)
+
+ # clear profile lines when the perspective changes (plane browsed changed)
+ self.__planeSelection.sigPlaneSelectionChanged.connect(
+ self._plot.profile.getProfilePlot().clear)
+ self.__planeSelection.sigPlaneSelectionChanged.connect(
+ self._plot.profile.clearProfile)
+
+ def setOptionVisible(self, isVisible):
+ """
+ Set the visibility of the browsing options.
+
+ :param bool isVisible: True to have the options visible, else False
+ """
+ self._browser.setVisible(isVisible)
+ self.__planeSelection.setVisible(isVisible)
+
+ def _plotCallback(self, eventDict):
+ """Callback for plot events.
+
+ Emit :attr:`valueChanged` signal, with (x, y, value) tuple of the
+ cursor location in the plot."""
+ if eventDict['event'] == 'mouseMoved':
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ data = activeImage.getData()
+ height, width = data.shape
+
+ # Get corresponding coordinate in image
+ origin = activeImage.getOrigin()
+ scale = activeImage.getScale()
+ x = int((eventDict['x'] - origin[0]) / scale[0])
+ y = int((eventDict['y'] - origin[1]) / scale[1])
+
+ if 0 <= x < width and 0 <= y < height:
+ self.valueChanged.emit(float(x), float(y),
+ data[y][x])
+ else:
+ self.valueChanged.emit(float(x), float(y),
+ None)
+
+ def __setPerspective(self, perspective):
+ """Function called when the browsed/orthogonal dimension changes.
+ Updates :attr:`_perspective`, transposes data, updates the plot,
+ emits :attr:`sigPlaneSelectionChanged` and :attr:`sigStackChanged`.
+
+ :param int perspective: the new browsed dimension
+ """
+ if perspective == self._perspective:
+ return
+ else:
+ if perspective > 2 or perspective < 0:
+ raise ValueError(
+ "Perspective must be 0, 1 or 2, not %s" % perspective)
+
+ self._perspective = perspective
+ self.__createTransposedView()
+ self.__updateFrameNumber(self._browser.value())
+ self._plot.resetZoom()
+ self.__updatePlotLabels()
+ self._browser_label.setText("Image index (Dim%d):" %
+ (self._first_stack_dimension + perspective))
+
+ self.sigPlaneSelectionChanged.emit(perspective)
+ self.sigStackChanged.emit(self._stack.size if
+ self._stack is not None else 0)
+
+ def __updatePlotLabels(self):
+ """Update plot axes labels depending on perspective"""
+ y, x = (1, 2) if self._perspective == 0 else \
+ (0, 2) if self._perspective == 1 else (0, 1)
+ self.setGraphXLabel(self.__dimensionsLabels[x])
+ self.setGraphYLabel(self.__dimensionsLabels[y])
+
+ def __createTransposedView(self):
+ """Create the new view on the stack depending on the perspective
+ (set orthogonal axis browsed on the viewer as first dimension)
+ """
+ assert self._stack is not None
+ assert 0 <= self._perspective < 3
+
+ # ensure we have the stack encapsulated in an array like object
+ # having a transpose() method
+ if isinstance(self._stack, numpy.ndarray):
+ self.__transposed_view = self._stack
+
+ elif h5py is not None and isinstance(self._stack, h5py.Dataset) or \
+ isinstance(self._stack, DatasetView):
+ self.__transposed_view = DatasetView(self._stack)
+
+ elif isinstance(self._stack, ListOfImages):
+ self.__transposed_view = ListOfImages(self._stack)
+
+ # transpose the array like object if necessary
+ if self._perspective == 1:
+ self.__transposed_view = self.__transposed_view.transpose((1, 0, 2))
+ elif self._perspective == 2:
+ self.__transposed_view = self.__transposed_view.transpose((2, 0, 1))
+
+ self._browser.setRange(0, self.__transposed_view.shape[0] - 1)
+ self._browser.setValue(0)
+
+ def setFrameNumber(self, number):
+ """Set the frame selection to a specific value\
+
+ :param int number: Number of the frame
+ """
+ self._browser.setValue(number)
+
+ def __updateFrameNumber(self, index):
+ """Update the current image.
+
+ :param index: index of the frame to be displayed
+ """
+ assert self.__transposed_view is not None
+ self._plot.addImage(self.__transposed_view[index, :, :],
+ origin=self._getImageOrigin(),
+ scale=self._getImageScale(),
+ legend=self.__imageLegend,
+ resetzoom=False, replace=False)
+ self._plot.setGraphTitle("Image z=%g" % self._getImageZ(index))
+
+ def _set3DScaleAndOrigin(self, calibrations):
+ """Set scale and origin for all 3 axes, to be used when plotting
+ an image.
+
+ See setStack for parameter documentation
+ """
+ if calibrations is None:
+ self.calibrations3D = (calibration.NoCalibration(),
+ calibration.NoCalibration(),
+ calibration.NoCalibration())
+ else:
+ self.calibrations3D = []
+ for calib in calibrations:
+ if hasattr(calib, "__len__") and len(calib) == 2:
+ calib = calibration.LinearCalibration(calib[0], calib[1])
+ elif calib is None:
+ calib = calibration.NoCalibration()
+ elif not isinstance(calib, calibration.AbstractCalibration):
+ raise TypeError("calibration must be a 2-tuple, None or" +
+ " an instance of an AbstractCalibration " +
+ "subclass")
+ self.calibrations3D.append(calib)
+
+ def _getXYZCalibs(self):
+ xy_dims = [0, 1, 2]
+ xy_dims.remove(self._perspective)
+
+ xcalib = self.calibrations3D[max(xy_dims)]
+ ycalib = self.calibrations3D[min(xy_dims)]
+ zcalib = self.calibrations3D[self._perspective]
+
+ return xcalib, ycalib, zcalib
+
+ def _getImageScale(self):
+ """
+ :return: 2-tuple (XScale, YScale) for current image view
+ """
+ xcalib, ycalib, _zcalib = self._getXYZCalibs()
+ return xcalib.get_slope(), ycalib.get_slope()
+
+ def _getImageOrigin(self):
+ """
+ :return: 2-tuple (XOrigin, YOrigin) for current image view
+ """
+ xcalib, ycalib, _zcalib = self._getXYZCalibs()
+ return xcalib(0), ycalib(0)
+
+ def _getImageZ(self, index):
+ """
+ :param idx: 0-based image index in the stack
+ :return: calibrated Z value corresponding to the image idx
+ """
+ _xcalib, _ycalib, zcalib = self._getXYZCalibs()
+ return zcalib(index)
+
+ # public API
+ def setStack(self, stack, perspective=0, reset=True, calibrations=None):
+ """Set the 3D stack.
+
+ The perspective parameter is used to define which dimension of the 3D
+ array is to be used as frame index. The lowest remaining dimension
+ number is the row index of the displayed image (Y axis), and the highest
+ remaining dimension is the column index (X axis).
+
+ :param stack: 3D stack, or `None` to clear plot.
+ :type stack: 3D numpy.ndarray, or 3D h5py.Dataset, or list/tuple of 2D
+ numpy arrays, or None.
+ :param int perspective: Dimension for the frame index: 0, 1 or 2.
+ By default, the dimension for the image index is the first
+ dimension of the 3D stack (``perspective=0``).
+ :param bool reset: Whether to reset zoom or not.
+ :param calibrations: Sequence of 3 calibration objects for each axis.
+ These objects can be a subclass of :class:`AbstractCalibration`,
+ or 2-tuples *(a, b)* where *a* is the y-intercept and *b* is the
+ slope of a linear calibration (:math:`x \mapsto a + b x`)
+ """
+ if stack is None:
+ self.clear()
+ self.sigStackChanged.emit(0)
+ return
+
+ self._set3DScaleAndOrigin(calibrations)
+
+ # stack as list of 2D arrays: must be converted into an array_like
+ if not isinstance(stack, numpy.ndarray):
+ if h5py is None or not isinstance(stack, h5py.Dataset):
+ try:
+ assert hasattr(stack, "__len__")
+ for img in stack:
+ assert hasattr(img, "shape")
+ assert len(img.shape) == 2
+ except AssertionError:
+ raise ValueError(
+ "Stack must be a 3D array/dataset or a list of " +
+ "2D arrays.")
+ stack = ListOfImages(stack)
+
+ assert len(stack.shape) == 3, "data must be 3D"
+
+ self._stack = stack
+ self.__createTransposedView()
+
+ if perspective != self._perspective:
+ self.__setPerspective(perspective)
+
+ # This call to setColormap redefines the meaning of autoscale
+ # for 3D volume: take global min/max rather than frame min/max
+ if self.__autoscaleCmap:
+ self.setColormap(autoscale=True)
+
+ # init plot
+ self._plot.addImage(self.__transposed_view[0, :, :],
+ legend=self.__imageLegend,
+ colormap=self.getColormap(),
+ origin=self._getImageOrigin(),
+ scale=self._getImageScale(),
+ resetzoom=False)
+ self._plot.setActiveImage(self.__imageLegend)
+ self._plot.setGraphTitle("Image z=%g" % self._getImageZ(0))
+ self.__updatePlotLabels()
+
+ if reset:
+ self._plot.resetZoom()
+
+ # enable and init browser
+ self._browser.setEnabled(True)
+
+ if perspective != self._perspective:
+ self.__planeSelection.setPerspective(perspective)
+ # this causes self.__setPerspective to be called, which emits
+ # sigStackChanged and sigPlaneSelectionChanged
+
+ else:
+ self.sigStackChanged.emit(stack.size)
+
+ def getStack(self, copy=True, returnNumpyArray=False):
+ """Get the original stack, as a 3D array or dataset.
+
+ The output has the form: [data, params]
+ where params is a dictionary containing display parameters.
+
+ :param bool copy: If True (default), then the object is copied
+ and returned as a numpy array.
+ Else, a reference to original data is returned, if possible.
+ If the original data is not a numpy array and parameter
+ returnNumpyArray is True, a copy will be made anyway.
+ :param bool returnNumpyArray: If True, the returned object is
+ guaranteed to be a numpy array.
+ :return: 3D stack and parameters.
+ :rtype: (numpy.ndarray, dict)
+ """
+ image = self.getActiveImage()
+ if image is None:
+ return None
+
+ if isinstance(image, items.ColormapMixIn):
+ colormap = image.getColormap()
+ else:
+ colormap = None
+
+ params = {
+ 'info': image.getInfo(),
+ 'origin': image.getOrigin(),
+ 'scale': image.getScale(),
+ 'z': image.getZValue(),
+ 'selectable': image.isSelectable(),
+ 'draggable': image.isDraggable(),
+ 'colormap': colormap,
+ 'xlabel': image.getXLabel(),
+ 'ylabel': image.getYLabel(),
+ }
+ if returnNumpyArray or copy:
+ return numpy.array(self._stack, copy=copy), params
+
+ # if a list of 2D arrays was cast into a ListOfImages,
+ # return the original list
+ if isinstance(self._stack, ListOfImages):
+ return self._stack.images, params
+
+ return self._stack, params
+
+ def getCurrentView(self, copy=True, returnNumpyArray=False):
+ """Get the stack, as it is currently displayed.
+
+ The first index of the returned stack is always the frame
+ index. If the perspective has been changed in the widget since the
+ data was first loaded, this will be reflected in the order of the
+ dimensions of the returned object.
+
+ The output has the form: [data, params]
+ where params is a dictionary containing display parameters.
+
+ :param bool copy: If True (default), then the object is copied
+ and returned as a numpy array.
+ Else, a reference to original data is returned, if possible.
+ If the original data is not a numpy array and parameter
+ `returnNumpyArray` is `True`, a copy will be made anyway.
+ :param bool returnNumpyArray: If `True`, the returned object is
+ guaranteed to be a numpy array.
+ :return: 3D stack and parameters.
+ :rtype: (numpy.ndarray, dict)
+ """
+ image = self.getActiveImage()
+ if image is None:
+ return None
+
+ if isinstance(image, items.ColormapMixIn):
+ colormap = image.getColormap()
+ else:
+ colormap = None
+
+ params = {
+ 'info': image.getInfo(),
+ 'origin': image.getOrigin(),
+ 'scale': image.getScale(),
+ 'z': image.getZValue(),
+ 'selectable': image.isSelectable(),
+ 'draggable': image.isDraggable(),
+ 'colormap': colormap,
+ 'xlabel': image.getXLabel(),
+ 'ylabel': image.getYLabel(),
+ }
+ if returnNumpyArray or copy:
+ return numpy.array(self.__transposed_view, copy=copy), params
+ return self.__transposed_view, params
+
+ def getActiveImage(self, just_legend=False):
+ """Returns the currently active image object.
+
+ It returns None in case of not having an active image.
+
+ :param bool just_legend: True to get the legend of the image,
+ False (the default) to get the image data and info.
+ Note: :class:`StackView` uses the same legend for all frames.
+ :return: legend or image object
+ :rtype: str or list or None
+ """
+ return self._plot.getActiveImage(just_legend=just_legend)
+
+ def clear(self):
+ """Clear the widget:
+
+ - clear the plot
+ - clear the loaded data volume
+ """
+ self._stack = None
+ self.__transposed_view = None
+ self._perspective = 0
+ self._browser.setEnabled(False)
+ self._plot.clear()
+
+ def resetZoom(self):
+ """Reset the plot limits to the bounds of the data and redraw the plot.
+ """
+ self._plot.resetZoom()
+
+ def getGraphTitle(self):
+ """Return the plot main title as a str."""
+ return self._plot.getGraphTitle()
+
+ def setGraphTitle(self, title=""):
+ """Set the plot main title.
+
+ :param str title: Main title of the plot (default: '')
+ """
+ return self._plot.setGraphTitle(title)
+
+ def setLabels(self, labels=None):
+ """Set the labels to be displayed on the plot axes.
+
+ You must provide a sequence of 3 strings, corresponding to the 3
+ dimensions of the original data volume.
+ The proper label will automatically be selected for each plot axis
+ when the volume is rotated (when different axes are selected as the
+ X and Y axes).
+
+ :param list(str) labels: 3 labels corresponding to the 3 dimensions
+ of the data volumes.
+ """
+
+ default_labels = ["Dimension %d" % self._first_stack_dimension,
+ "Dimension %d" % (self._first_stack_dimension + 1),
+ "Dimension %d" % (self._first_stack_dimension + 2)]
+ if labels is None:
+ new_labels = default_labels
+ else:
+ # filter-out None
+ new_labels = []
+ for i, label in enumerate(labels):
+ new_labels.append(label or default_labels[i])
+
+ self.__dimensionsLabels = new_labels
+ self.__updatePlotLabels()
+
+ def getGraphXLabel(self):
+ """Return the current horizontal axis label as a str."""
+ return self._plot.getGraphXLabel()
+
+ def setGraphXLabel(self, label=None):
+ """Set the plot horizontal axis label.
+
+ :param str label: The horizontal axis label
+ """
+ if label is None:
+ label = self.__dimensionsLabels[1 if self._perspective == 2 else 2]
+ self._plot.setGraphXLabel(label)
+
+ def getGraphYLabel(self, axis='left'):
+ """Return the current vertical axis label as a str.
+
+ :param str axis: The Y axis for which to get the label (left or right)
+ """
+ return self._plot.getGraphYLabel(axis)
+
+ def setGraphYLabel(self, label=None, axis='left'):
+ """Set the vertical axis label on the plot.
+
+ :param str label: The Y axis label
+ :param str axis: The Y axis for which to set the label (left or right)
+ """
+ if label is None:
+ label = self.__dimensionsLabels[1 if self._perspective == 0 else 0]
+ self._plot.setGraphYLabel(label, axis)
+
+ def setYAxisInverted(self, flag=True):
+ """Set the Y axis orientation.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ self._plot.setYAxisInverted(flag)
+
+ def isYAxisInverted(self):
+ """Return True if Y axis goes from top to bottom, False otherwise."""
+ return self._backend.isYAxisInverted()
+
+ def getSupportedColormaps(self):
+ """Get the supported colormap names as a tuple of str.
+
+ The list should at least contain and start by:
+ ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue')
+ """
+ return self._plot.getSupportedColormaps()
+
+ def getColormap(self):
+ """Get the current colormap description.
+
+ :return: A description of the current colormap.
+ See :meth:`setColormap` for details.
+ :rtype: dict
+ """
+ # "default" colormap used by addImage when image is added without
+ # specifying a special colormap
+ return self._plot.getDefaultColormap()
+
+ def setColormap(self, colormap=None, normalization=None,
+ autoscale=None, vmin=None, vmax=None, colors=None):
+ """Set the colormap and update active image.
+
+ Parameters that are not provided are taken from the current colormap.
+
+ The colormap parameter can also be a dict with the following keys:
+
+ - *name*: string. The colormap to use:
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ - *normalization*: string. The mapping to use for the colormap:
+ either 'linear' or 'log'.
+ - *autoscale*: bool. Whether to use autoscale (True) or range
+ provided by keys
+ 'vmin' and 'vmax' (False).
+ - *vmin*: float. The minimum value of the range to use if 'autoscale'
+ is False.
+ - *vmax*: float. The maximum value of the range to use if 'autoscale'
+ is False.
+ - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8.
+ List of RGB or RGBA colors to use (only if name is None)
+
+ :param colormap: Name of the colormap in
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ Or the description of the colormap as a dict.
+ :type colormap: dict or str.
+ :param str normalization: Colormap mapping: 'linear' or 'log'.
+ :param bool autoscale: Whether to use autoscale or [vmin, vmax] range.
+ Default value of autoscale is True if data is a numpy array,
+ False if data is a h5py dataset.
+ :param float vmin: The minimum value of the range to use if
+ 'autoscale' is False.
+ :param float vmax: The maximum value of the range to use if
+ 'autoscale' is False.
+ :param numpy.ndarray colors: Only used if name is None.
+ Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
+ """
+ cmapDict = self.getColormap()
+
+ if isinstance(colormap, dict):
+ # Support colormap parameter as a dict
+ errmsg = "If colormap is provided as a dict, all other parameters"
+ errmsg += " must not be specified when calling setColormap"
+ assert normalization is None, errmsg
+ assert autoscale is None, errmsg
+ assert vmin is None, errmsg
+ assert vmax is None, errmsg
+ assert colors is None, errmsg
+ cmapDict.update(colormap)
+
+ else:
+ if colormap is not None:
+ cmapDict['name'] = colormap
+ if normalization is not None:
+ cmapDict['normalization'] = normalization
+ if colors is not None:
+ cmapDict['colors'] = colors
+
+ # Default meaning of autoscale is to reset min and max
+ # each time a new image is added to the plot.
+ # We want to use min and max of global volume,
+ # and not change them when browsing slides
+ cmapDict['autoscale'] = False
+
+ if autoscale is None:
+ # set default
+ autoscale = False
+ # TODO: assess cost of computing min/max for large 3D array
+ # if isinstance(self._stack, numpy.ndarray):
+ # autoscale = True
+ # else: # h5py.Dataset
+ # autoscale = False
+ elif autoscale and isinstance(self._stack, h5py.Dataset):
+ # h5py dataset has no min()/max() methods
+ raise RuntimeError(
+ "Cannot auto-scale colormap for a h5py dataset")
+ else:
+ autoscale = autoscale
+ self.__autoscaleCmap = autoscale
+ if autoscale and (self._stack is not None):
+ cmapDict['vmin'] = self._stack.min()
+ cmapDict['vmax'] = self._stack.max()
+ else:
+ if vmin is not None:
+ cmapDict['vmin'] = vmin
+ if vmax is not None:
+ cmapDict['vmax'] = vmax
+
+ cursorColor = cursorColorForColormap(cmapDict['name'])
+ self._plot.setInteractiveMode('zoom', color=cursorColor)
+
+ self._plot.setDefaultColormap(cmapDict)
+
+ # Update active image colormap
+ activeImage = self._plot.getActiveImage()
+ if isinstance(activeImage, items.ColormapMixIn):
+ activeImage.setColormap(self.getColormap())
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not."""
+ return self._plot.isKeepDataAspectRatio()
+
+ def setKeepDataAspectRatio(self, flag=True):
+ """Set whether the plot keeps data aspect ratio or not.
+
+ :param bool flag: True to respect data aspect ratio
+ """
+ self._plot.setKeepDataAspectRatio(flag)
+
+ def getProfileToolbar(self):
+ """Profile tools attached to this plot
+
+ See :class:`silx.gui.plot.Profile.Profile3DToolBar`
+ """
+ return self._plot.profile
+
+ def getProfileWindow1D(self):
+ """Plot window used to display 1D profile curve.
+
+ :return: :class:`Plot1D`
+ """
+ return self._plot.profile.getProfileWindow1D()
+
+ def getProfileWindow2D(self):
+ """Plot window used to display 2D profile image.
+
+ :return: :class:`Plot2D`
+ """
+ return self._plot.profile.getProfileWindow2D()
+
+ # kind of private methods, but needed by Profile
+ def remove(self, legend=None,
+ kind=('curve', 'image', 'item', 'marker')):
+ """See :meth:`Plot.Plot.remove`"""
+ self._plot.remove(legend, kind)
+
+ def setInteractiveMode(self, *args, **kwargs):
+ """
+ See :meth:`Plot.Plot.setInteractiveMode`
+ """
+ self._plot.setInteractiveMode(*args, **kwargs)
+
+ def addItem(self, *args, **kwargs):
+ """
+ See :meth:`Plot.Plot.addItem`
+ """
+ self._plot.addItem(*args, **kwargs)
+
+ def setFirstStackDimension(self, first_stack_dimension):
+ """When viewing the last 3 dimensions of an n-D array (n>3), you can
+ use this method to change the text in the combobox.
+
+ For instance, for a 7-D array, first stack dim is 4, so the default
+ "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions
+ numbers are 0-based).
+
+ :param int first_stack_dim: First stack dimension (n-3) when viewing the
+ last 3 dimensions of an n-D array.
+ """
+ old_state = self.__planeSelection.blockSignals(True)
+ self.__planeSelection.setFirstStackDimension(first_stack_dimension)
+ self.__planeSelection.blockSignals(old_state)
+ self._first_stack_dimension = first_stack_dimension
+ self._browser_label.setText("Image index (Dim%d):" % first_stack_dimension)
+
+
+class PlanesWidget(qt.QWidget):
+ """Widget for the plane/perspective selection
+
+ :param parent: the parent QWidget
+ """
+ sigPlaneSelectionChanged = qt.Signal(int)
+
+ def __init__(self, parent):
+ super(PlanesWidget, self).__init__(parent)
+
+ self.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum)
+ layout0 = qt.QHBoxLayout()
+ self.setLayout(layout0)
+ layout0.setContentsMargins(0, 0, 0, 0)
+
+ layout0.addWidget(qt.QLabel("Axes selection:"))
+
+ # By default, the first dimension (dim0) is the frame index/depth/z,
+ # the second dimension is the image row number/y axis
+ # and the third dimension is the image column index/x axis
+
+ # 1
+ # | 0
+ # |/__2
+ self.qcbAxisSelection = qt.QComboBox(self)
+ self._setCBChoices(first_stack_dimension=0)
+ self.qcbAxisSelection.currentIndexChanged[int].connect(
+ self.__planeSelectionChanged)
+
+ layout0.addWidget(self.qcbAxisSelection)
+
+ def __planeSelectionChanged(self, idx):
+ """Callback function when the combobox selection changes
+
+ idx is the dimension number orthogonal to the slice plane,
+ following the convention:
+
+ - slice plane Dim1-Dim2: perspective 0
+ - slice plane Dim0-Dim2: perspective 1
+ - slice plane Dim0-Dim1: perspective 2
+ """
+ self.sigPlaneSelectionChanged.emit(idx)
+
+ def _setCBChoices(self, first_stack_dimension):
+ self.qcbAxisSelection.clear()
+
+ dim1dim2 = 'Dim%d-Dim%d' % (first_stack_dimension + 1,
+ first_stack_dimension + 2)
+ dim0dim2 = 'Dim%d-Dim%d' % (first_stack_dimension,
+ first_stack_dimension + 2)
+ dim0dim1 = 'Dim%d-Dim%d' % (first_stack_dimension,
+ first_stack_dimension + 1)
+
+ self.qcbAxisSelection.addItem(icons.getQIcon("cube-front"), dim1dim2)
+ self.qcbAxisSelection.addItem(icons.getQIcon("cube-bottom"), dim0dim2)
+ self.qcbAxisSelection.addItem(icons.getQIcon("cube-left"), dim0dim1)
+
+ def setFirstStackDimension(self, first_stack_dim):
+ """When viewing the last 3 dimensions of an n-D array (n>3), you can
+ use this method to change the text in the combobox.
+
+ For instance, for a 7-D array, first stack dim is 4, so the default
+ "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions
+ numbers are 0-based).
+
+ :param int first_stack_dim: First stack dimension (n-3) when viewing the
+ last 3 dimensions of an n-D array.
+ """
+ self._setCBChoices(first_stack_dim)
+
+ def setPerspective(self, perspective):
+ """Update the combobox selection.
+
+ - slice plane Dim1-Dim2: perspective 0
+ - slice plane Dim0-Dim2: perspective 1
+ - slice plane Dim0-Dim1: perspective 2
+
+ :param perspective: Orthogonal dimension number (0, 1, or 2)
+ """
+ self.qcbAxisSelection.setCurrentIndex(perspective)
+
+
+class StackViewMainWindow(StackView):
+ """This class is a :class:`StackView` with a menu, an additional toolbar
+ to set the plot limits, and a status bar to display the value and 3D
+ index of the data samples hovered by the mouse cursor.
+
+ :param QWidget parent: Parent widget, or None
+ """
+ def __init__(self, parent=None):
+ self._dataInfo = None
+ super(StackViewMainWindow, self).__init__(parent)
+ self.setWindowFlags(qt.Qt.Window)
+
+ # Add toolbars and status bar
+ self.addToolBar(qt.Qt.BottomToolBarArea,
+ LimitsToolBar(plot=self._plot))
+
+ self.statusBar()
+
+ menu = self.menuBar().addMenu('File')
+ menu.addAction(self._plot.saveAction)
+ menu.addAction(self._plot.printAction)
+ menu.addSeparator()
+ action = menu.addAction('Quit')
+ action.triggered[bool].connect(qt.QApplication.instance().quit)
+
+ menu = self.menuBar().addMenu('Edit')
+ menu.addAction(self._plot.copyAction)
+ menu.addSeparator()
+ menu.addAction(self._plot.resetZoomAction)
+ menu.addAction(self._plot.colormapAction)
+ menu.addAction(PlotActions.KeepAspectRatioAction(self._plot, self))
+ menu.addAction(PlotActions.YAxisInvertedAction(self._plot, self))
+
+ menu = self.menuBar().addMenu('Profile')
+ menu.addAction(self._plot.profile.browseAction)
+ menu.addAction(self._plot.profile.hLineAction)
+ menu.addAction(self._plot.profile.vLineAction)
+ menu.addAction(self._plot.profile.lineAction)
+ menu.addSeparator()
+ menu.addAction(self._plot.profile.clearAction)
+ self._plot.profile.profile3dAction.computeProfileIn2D()
+ menu.addMenu(self._plot.profile.profile3dAction.menu())
+
+ # Connect to StackView's signal
+ self.valueChanged.connect(self._statusBarSlot)
+
+ def _statusBarSlot(self, x, y, value):
+ """Update status bar with coordinates/value from plots."""
+ # todo (after implementing calibration):
+ # - use floats for (x, y, z)
+ # - display both indices (dim0, dim1, dim2) and (x, y, z)
+ msg = "Cursor out of range"
+ if x is not None and y is not None:
+ img_idx = self._browser.value()
+
+ if self._perspective == 0:
+ dim0, dim1, dim2 = img_idx, int(y), int(x)
+ elif self._perspective == 1:
+ dim0, dim1, dim2 = int(y), img_idx, int(x)
+ elif self._perspective == 2:
+ dim0, dim1, dim2 = int(y), int(x), img_idx
+
+ msg = 'Position: (%d, %d, %d)' % (dim0, dim1, dim2)
+ if value is not None:
+ msg += ', Value: %g' % value
+ if self._dataInfo is not None:
+ msg = self._dataInfo + ', ' + msg
+
+ self.statusBar().showMessage(msg)
+
+ def setStack(self, stack, *args, **kwargs):
+ """Set the displayed stack.
+
+ See :meth:`StackView.setStack` for details.
+ """
+ if hasattr(stack, 'dtype') and hasattr(stack, 'shape'):
+ assert len(stack.shape) == 3
+ nframes, height, width = stack.shape
+ self._dataInfo = 'Data: %dx%dx%d (%s)' % (nframes, height, width,
+ str(stack.dtype))
+ self.statusBar().showMessage(self._dataInfo)
+ else:
+ self._dataInfo = None
+
+ # Set the new stack in StackView widget
+ super(StackViewMainWindow, self).setStack(stack, *args, **kwargs)
+ self.setStatusBar(None)
diff --git a/silx/gui/plot/_BaseMaskToolsWidget.py b/silx/gui/plot/_BaseMaskToolsWidget.py
new file mode 100644
index 0000000..91bbe1c
--- /dev/null
+++ b/silx/gui/plot/_BaseMaskToolsWidget.py
@@ -0,0 +1,1138 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module is a collection of base classes used in modules
+:mod:`.MaskToolsWidget` (images) and :mod:`.ScatterMaskToolsWidget`
+"""
+from __future__ import division
+
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "20/04/2017"
+
+import os
+
+import numpy
+
+from silx.gui import qt, icons
+from silx.gui.plot.Colors import rgba
+
+
+class BaseMask(qt.QObject):
+ """Base class for :class:`ImageMask` and :class:`ScatterMask`
+
+ A mask field with update operations.
+
+ A mask is an array of the same shape as some underlying data. The mask
+ array stores integer values in the range 0-255, to allow for 254 levels
+ of mask (value 0 is reserved for unmasked data).
+
+ The mask is updated using spatial selection methods: data located inside
+ a selected area is masked with a specified mask level.
+
+ """
+
+ sigChanged = qt.Signal()
+ """Signal emitted when the mask has changed"""
+
+ sigUndoable = qt.Signal(bool)
+ """Signal emitted when undo becomes possible/impossible"""
+
+ sigRedoable = qt.Signal(bool)
+ """Signal emitted when redo becomes possible/impossible"""
+
+ def __init__(self, dataItem=None):
+ self.historyDepth = 10
+ """Maximum number of operation stored in history list for undo"""
+ # Init lists for undo/redo
+ self._history = []
+ self._redo = []
+
+ # Store the mask
+ self._mask = numpy.array((), dtype=numpy.uint8)
+
+ # Store the plot item to be masked
+ self._dataItem = None
+ if dataItem is not None:
+ self.setDataItem(dataItem)
+ self.reset(self.getDataValues().shape)
+
+ super(BaseMask, self).__init__()
+
+ def setDataItem(self, item):
+ """Set a data item
+
+ :param item: A plot item, subclass of :class:`silx.gui.plot.items.Item`
+ :return:
+ """
+ self._dataItem = item
+
+ def getDataValues(self):
+ """Return data values, as a numpy array with the same shape
+ as the mask.
+
+ This method must be implemented in a subclass, as the way of
+ accessing data depends on the data item passed to :meth:`setDataItem`
+
+ :return: Data values associated with the data item.
+ :rtype: numpy.ndarray
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def _notify(self):
+ """Notify of mask change."""
+ self.sigChanged.emit()
+
+ def getMask(self, copy=True):
+ """Get the current mask as a numpy array.
+
+ :param bool copy: True (default) to get a copy of the mask.
+ If False, the returned array MUST not be modified.
+ :return: The array of the mask with dimension of the data to be masked.
+ :rtype: numpy.ndarray of uint8
+ """
+ return numpy.array(self._mask, copy=copy)
+
+ def setMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask: The array to use for the mask.
+ :type mask: numpy.ndarray of uint8, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ """
+ self._mask = numpy.array(mask, copy=copy, order='C', dtype=numpy.uint8)
+ self._notify()
+
+ # History control
+ def resetHistory(self):
+ """Reset history"""
+ self._history = [numpy.array(self._mask, copy=True)]
+ self._redo = []
+ self.sigUndoable.emit(False)
+ self.sigRedoable.emit(False)
+
+ def commit(self):
+ """Append the current mask to history if changed"""
+ if (not self._history or self._redo or
+ not numpy.all(numpy.equal(self._mask, self._history[-1]))):
+ if self._redo:
+ self._redo = [] # Reset redo as a new action as been performed
+ self.sigRedoable[bool].emit(False)
+
+ while len(self._history) >= self.historyDepth:
+ self._history.pop(0)
+ self._history.append(numpy.array(self._mask, copy=True))
+
+ if len(self._history) == 2:
+ self.sigUndoable.emit(True)
+
+ def undo(self):
+ """Restore previous mask if any"""
+ if len(self._history) > 1:
+ self._redo.append(self._history.pop())
+ self._mask = numpy.array(self._history[-1], copy=True)
+ self._notify() # Do not store this change in history
+
+ if len(self._redo) == 1: # First redo
+ self.sigRedoable.emit(True)
+ if len(self._history) == 1: # Last value in history
+ self.sigUndoable.emit(False)
+
+ def redo(self):
+ """Restore previously undone modification if any"""
+ if self._redo:
+ self._mask = self._redo.pop()
+ self._history.append(numpy.array(self._mask, copy=True))
+ self._notify()
+
+ if not self._redo: # No more redo
+ self.sigRedoable.emit(False)
+ if len(self._history) == 2: # Something to undo
+ self.sigUndoable.emit(True)
+
+ # Whole mask operations
+
+ def clear(self, level):
+ """Set all values of the given mask level to 0.
+
+ :param int level: Value of the mask to set to 0.
+ """
+ assert 0 < level < 256
+ self._mask[self._mask == level] = 0
+ self._notify()
+
+ def invert(self, level):
+ """Invert mask of the given mask level.
+
+ 0 values become level and level values become 0.
+
+ :param int level: The level to invert.
+ """
+ assert 0 < level < 256
+ masked = self._mask == level
+ self._mask[self._mask == 0] = level
+ self._mask[masked] = 0
+ self._notify()
+
+ def reset(self, shape=None):
+ """Reset the mask to zero and change its shape.
+
+ :param shape: Shape of the new mask with the correct dimensionality
+ with regards to the data dimensionality,
+ or None to have an empty mask
+ :type shape: tuple of int
+ """
+ if shape is None:
+ # assume dimensionality never changes
+ shape = (0, ) * len(self._mask.shape) # empty array
+ shapeChanged = (shape != self._mask.shape)
+ self._mask = numpy.zeros(shape, dtype=numpy.uint8)
+ if shapeChanged:
+ self.resetHistory()
+
+ self._notify()
+
+ # To be implemented
+ def save(self, filename, kind):
+ """Save current mask in a file
+
+ :param str filename: The file where to save to mask
+ :param str kind: The kind of file to save (e.g 'npy')
+ :raise Exception: Raised if the file writing fail
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ # update thresholds
+ def updateStencil(self, level, stencil, mask=True):
+ """Mask/Unmask points from boolean mask: all elements that are True
+ in the boolean mask are set to ``level`` (if ``mask=True``) or 0
+ (if ``mask=False``)
+
+ :param int level: Mask level to update.
+ :param stencil: Boolean mask.
+ :type stencil: numpy.array of same dimension as the mask
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ if mask:
+ self._mask[stencil] = level
+ else:
+ self._mask[numpy.logical_and(self._mask == level, stencil)] = 0
+ self._notify()
+
+ def updateBelowThreshold(self, level, threshold, mask=True):
+ """Mask/unmask all points whose values are below a threshold.
+
+ :param int level:
+ :param float threshold: Threshold
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ self.updateStencil(level,
+ self.getDataValues() < threshold,
+ mask)
+
+ def updateBetweenThresholds(self, level, min_, max_, mask=True):
+ """Mask/unmask all points whose values are in a range.
+
+ :param int level:
+ :param float min_: Lower threshold
+ :param float max_: Upper threshold
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ stencil = numpy.logical_and(min_ <= self.getDataValues(),
+ self.getDataValues() <= max_)
+ self.updateStencil(level, stencil, mask)
+
+ def updateAboveThreshold(self, level, threshold, mask=True):
+ """Mask/unmask all points whose values are above a threshold.
+
+ :param int level: Mask level to update.
+ :param float threshold: Threshold.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ self.updateStencil(level,
+ self.getDataValues() > threshold,
+ mask)
+
+ def updateNotFinite(self, level, mask=True):
+ """Mask/unmask all points whose values are not finite.
+
+ :param int level: Mask level to update.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ self.updateStencil(level,
+ numpy.logical_not(numpy.isfinite(self.getDataValues())),
+ mask)
+
+ # Drawing operations:
+ def updateRectangle(self, level, row, col, height, width, mask=True):
+ """Mask/Unmask data inside a rectangle, with the given mask level.
+
+ :param int level: Mask level to update, in range 1-255.
+ :param row: Starting row/y of the rectangle
+ :param col: Starting column/x of the rectangle
+ :param height:
+ :param width:
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updatePolygon(self, level, vertices, mask=True):
+ """Mask/Unmask data inside a polygon, with the given mask level.
+
+ :param int level: Mask level to update.
+ :param vertices: Nx2 array of polygon corners as (row, col) / (y, x)
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updatePoints(self, level, rows, cols, mask=True):
+ """Mask/Unmask points with given coordinates.
+
+ :param int level: Mask level to update.
+ :param rows: Rows/ordinates (y) of selected points
+ :type rows: 1D numpy.ndarray
+ :param cols: Columns/abscissa (x) of selected points
+ :type cols: 1D numpy.ndarray
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updateDisk(self, level, crow, ccol, radius, mask=True):
+ """Mask/Unmask data located inside a disk of the given mask level.
+
+ :param int level: Mask level to update.
+ :param crow: Disk center row/ordinate (y).
+ :param ccol: Disk center column/abscissa.
+ :param float radius: Radius of the disk in mask array unit
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updateLine(self, level, row0, col0, row1, col1, width, mask=True):
+ """Mask/Unmask a line of the given mask level.
+
+ :param int level: Mask level to update.
+ :param row0: Row/y of the starting point.
+ :param col0: Column/x of the starting point.
+ :param row1: Row/y of the end point.
+ :param col1: Column/x of the end point.
+ :param width: Width of the line in mask array unit.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+
+class BaseMaskToolsWidget(qt.QWidget):
+ """Base class for :class:`MaskToolsWidget` (image mask) and
+ :class:`scatterMaskToolsWidget`"""
+
+ sigMaskChanged = qt.Signal()
+ _maxLevelNumber = 255
+
+ def __init__(self, parent=None, plot=None):
+ # register if the user as force a color for the corresponding mask level
+ self._defaultColors = numpy.ones((self._maxLevelNumber + 1), dtype=numpy.bool)
+ # overlays colors set by the user
+ self._overlayColors = numpy.zeros((self._maxLevelNumber + 1, 3), dtype=numpy.float32)
+
+ self._plot = plot
+ self._maskName = '__MASK_TOOLS_%d' % id(self) # Legend of the mask
+
+ self._colormap = {
+ 'name': None,
+ 'normalization': 'linear',
+ 'autoscale': False,
+ 'vmin': 0, 'vmax': self._maxLevelNumber,
+ 'colors': None}
+ self._defaultOverlayColor = rgba('gray') # Color of the mask
+ self._setMaskColors(1, 0.5)
+
+ self._mask.sigChanged.connect(self._updatePlotMask)
+ self._mask.sigChanged.connect(self._emitSigMaskChanged)
+
+ self._drawingMode = None # Store current drawing mode
+ self._lastPencilPos = None
+ self._multipleMasks = 'exclusive'
+
+ super(BaseMaskToolsWidget, self).__init__(parent)
+
+ self._maskFileDir = qt.QDir.home().absolutePath()
+ self.plot.sigInteractiveModeChanged.connect(
+ self._interactiveModeChanged)
+
+ def _emitSigMaskChanged(self):
+ """Notify mask changes"""
+ self.sigMaskChanged.emit()
+
+ def getSelectionMask(self, copy=True):
+ """Get the current mask as a numpy array.
+
+ :param bool copy: True (default) to get a copy of the mask.
+ If False, the returned array MUST not be modified.
+ :return: The array of the mask with dimension of the 'active' plot item.
+ If there is no active image or scatter, an empty array is
+ returned.
+ :rtype: numpy.ndarray of uint8
+ """
+ return self._mask.getMask(copy=copy)
+
+ def multipleMasks(self):
+ """Return the current mode of multiple masks support.
+
+ See :meth:`setMultipleMasks`
+ """
+ return self._multipleMasks
+
+ def setMultipleMasks(self, mode):
+ """Set the mode of multiple masks support.
+
+ Available modes:
+
+ - 'single': Edit a single level of mask
+ - 'exclusive': Supports to 256 levels of non overlapping masks
+
+ :param str mode: The mode to use
+ """
+ assert mode in ('exclusive', 'single')
+ if mode != self._multipleMasks:
+ self._multipleMasks = mode
+ self.levelWidget.setVisible(self._multipleMasks != 'single')
+ self.clearAllBtn.setVisible(self._multipleMasks != 'single')
+
+ @property
+ def maskFileDir(self):
+ """The directory from which to load/save mask from/to files."""
+ if not os.path.isdir(self._maskFileDir):
+ self._maskFileDir = qt.QDir.home().absolutePath()
+ return self._maskFileDir
+
+ @maskFileDir.setter
+ def maskFileDir(self, maskFileDir):
+ self._maskFileDir = str(maskFileDir)
+
+ @property
+ def plot(self):
+ """The :class:`.PlotWindow` this widget is attached to."""
+ return self._plot
+
+ def setDirection(self, direction=qt.QBoxLayout.LeftToRight):
+ """Set the direction of the layout of the widget
+
+ :param direction: QBoxLayout direction
+ """
+ self.layout().setDirection(direction)
+
+ def _initWidgets(self):
+ """Create widgets"""
+ layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight)
+ layout.addWidget(self._initMaskGroupBox())
+ layout.addWidget(self._initDrawGroupBox())
+ layout.addWidget(self._initThresholdGroupBox())
+ layout.addStretch(1)
+ self.setLayout(layout)
+
+ @staticmethod
+ def _hboxWidget(*widgets, **kwargs):
+ """Place widgets in widget with horizontal layout
+
+ :param widgets: Widgets to position horizontally
+ :param bool stretch: True for trailing stretch (default),
+ False for no trailing stretch
+ :return: A QWidget with a QHBoxLayout
+ """
+ stretch = kwargs.get('stretch', True)
+
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ for widget in widgets:
+ layout.addWidget(widget)
+ if stretch:
+ layout.addStretch(1)
+ widget = qt.QWidget()
+ widget.setLayout(layout)
+ return widget
+
+ def _initTransparencyWidget(self):
+ """ Init the mask transparency widget """
+ transparencyWidget = qt.QWidget(self)
+ grid = qt.QGridLayout()
+ grid.setContentsMargins(0, 0, 0, 0)
+ self.transparencySlider = qt.QSlider(qt.Qt.Horizontal, parent=transparencyWidget)
+ self.transparencySlider.setRange(3, 10)
+ self.transparencySlider.setValue(8)
+ self.transparencySlider.setToolTip(
+ 'Set the transparency of the mask display')
+ self.transparencySlider.valueChanged.connect(self._updateColors)
+ grid.addWidget(qt.QLabel('Display:', parent=transparencyWidget), 0, 0)
+ grid.addWidget(self.transparencySlider, 0, 1, 1, 3)
+ grid.addWidget(qt.QLabel('<small><b>Transparent</b></small>', parent=transparencyWidget), 1, 1)
+ grid.addWidget(qt.QLabel('<small><b>Opaque</b></small>', parent=transparencyWidget), 1, 3)
+ transparencyWidget.setLayout(grid)
+ return transparencyWidget
+
+ def _initMaskGroupBox(self):
+ """Init general mask operation widgets"""
+
+ # Mask level
+ self.levelSpinBox = qt.QSpinBox()
+ self.levelSpinBox.setRange(1, self._maxLevelNumber)
+ self.levelSpinBox.setToolTip(
+ 'Choose which mask level is edited.\n'
+ 'A mask can have up to 255 non-overlapping levels.')
+ self.levelSpinBox.valueChanged[int].connect(self._updateColors)
+ self.levelWidget = self._hboxWidget(qt.QLabel('Mask level:'),
+ self.levelSpinBox)
+ # Transparency
+ self.transparencyWidget = self._initTransparencyWidget()
+
+ # Buttons group
+ invertBtn = qt.QPushButton('Invert')
+ invertBtn.setShortcut(qt.Qt.CTRL + qt.Qt.Key_I)
+ invertBtn.setToolTip('Invert current mask <b>%s</b>' %
+ invertBtn.shortcut().toString())
+ invertBtn.clicked.connect(self._handleInvertMask)
+
+ clearBtn = qt.QPushButton('Clear')
+ clearBtn.setShortcut(qt.QKeySequence.Delete)
+ clearBtn.setToolTip('Clear current mask level <b>%s</b>' %
+ clearBtn.shortcut().toString())
+ clearBtn.clicked.connect(self._handleClearMask)
+
+ invertClearWidget = self._hboxWidget(
+ invertBtn, clearBtn, stretch=False)
+
+ undoBtn = qt.QPushButton('Undo')
+ undoBtn.setShortcut(qt.QKeySequence.Undo)
+ undoBtn.setToolTip('Undo last mask change <b>%s</b>' %
+ undoBtn.shortcut().toString())
+ self._mask.sigUndoable.connect(undoBtn.setEnabled)
+ undoBtn.clicked.connect(self._mask.undo)
+
+ redoBtn = qt.QPushButton('Redo')
+ redoBtn.setShortcut(qt.QKeySequence.Redo)
+ redoBtn.setToolTip('Redo last undone mask change <b>%s</b>' %
+ redoBtn.shortcut().toString())
+ self._mask.sigRedoable.connect(redoBtn.setEnabled)
+ redoBtn.clicked.connect(self._mask.redo)
+
+ undoRedoWidget = self._hboxWidget(undoBtn, redoBtn, stretch=False)
+
+ self.clearAllBtn = qt.QPushButton('Clear all')
+ self.clearAllBtn.setToolTip('Clear all mask levels')
+ self.clearAllBtn.clicked.connect(self.resetSelectionMask)
+
+ loadBtn = qt.QPushButton('Load...')
+ loadBtn.clicked.connect(self._loadMask)
+
+ saveBtn = qt.QPushButton('Save...')
+ saveBtn.clicked.connect(self._saveMask)
+
+ self.loadSaveWidget = self._hboxWidget(loadBtn, saveBtn, stretch=False)
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(self.levelWidget)
+ layout.addWidget(self.transparencyWidget)
+ layout.addWidget(invertClearWidget)
+ layout.addWidget(undoRedoWidget)
+ layout.addWidget(self.clearAllBtn)
+ layout.addWidget(self.loadSaveWidget)
+ layout.addStretch(1)
+
+ maskGroup = qt.QGroupBox('Mask')
+ maskGroup.setLayout(layout)
+ return maskGroup
+
+ def _initDrawGroupBox(self):
+ """Init drawing tools widgets"""
+ layout = qt.QVBoxLayout()
+
+ # Draw tools
+ self.browseAction = qt.QAction(
+ icons.getQIcon('normal'), 'Browse', None)
+ self.browseAction.setShortcut(qt.QKeySequence(qt.Qt.Key_B))
+ self.browseAction.setToolTip(
+ 'Disables drawing tools, enables zooming interaction mode'
+ ' <b>B</b>')
+ self.browseAction.setCheckable(True)
+ self.browseAction.triggered.connect(self._activeBrowseMode)
+ self.addAction(self.browseAction)
+
+ self.rectAction = qt.QAction(
+ icons.getQIcon('shape-rectangle'), 'Rectangle selection', None)
+ self.rectAction.setToolTip(
+ 'Rectangle selection tool: (Un)Mask a rectangular region <b>R</b>')
+ self.rectAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R))
+ self.rectAction.setCheckable(True)
+ self.rectAction.triggered.connect(self._activeRectMode)
+ self.addAction(self.rectAction)
+
+ self.polygonAction = qt.QAction(
+ icons.getQIcon('shape-polygon'), 'Polygon selection', None)
+ self.polygonAction.setShortcut(qt.QKeySequence(qt.Qt.Key_S))
+ self.polygonAction.setToolTip(
+ 'Polygon selection tool: (Un)Mask a polygonal region <b>S</b><br>'
+ 'Left-click to place polygon corners<br>'
+ 'Right-click to place the last corner')
+ self.polygonAction.setCheckable(True)
+ self.polygonAction.triggered.connect(self._activePolygonMode)
+ self.addAction(self.polygonAction)
+
+ self.pencilAction = qt.QAction(
+ icons.getQIcon('draw-pencil'), 'Pencil tool', None)
+ self.pencilAction.setShortcut(qt.QKeySequence(qt.Qt.Key_P))
+ self.pencilAction.setToolTip(
+ 'Pencil tool: (Un)Mask using a pencil <b>P</b>')
+ self.pencilAction.setCheckable(True)
+ self.pencilAction.triggered.connect(self._activePencilMode)
+ self.addAction(self.polygonAction)
+
+ self.drawActionGroup = qt.QActionGroup(self)
+ self.drawActionGroup.setExclusive(True)
+ self.drawActionGroup.addAction(self.browseAction)
+ self.drawActionGroup.addAction(self.rectAction)
+ self.drawActionGroup.addAction(self.polygonAction)
+ self.drawActionGroup.addAction(self.pencilAction)
+
+ self.browseAction.setChecked(True)
+
+ self.drawButtons = {}
+ for action in self.drawActionGroup.actions():
+ btn = qt.QToolButton()
+ btn.setDefaultAction(action)
+ self.drawButtons[action.text()] = btn
+ container = self._hboxWidget(*self.drawButtons.values())
+ layout.addWidget(container)
+
+ # Mask/Unmask radio buttons
+ maskRadioBtn = qt.QRadioButton('Mask')
+ maskRadioBtn.setToolTip(
+ 'Drawing masks with current level. Press <b>Ctrl</b> to unmask')
+ maskRadioBtn.setChecked(True)
+
+ unmaskRadioBtn = qt.QRadioButton('Unmask')
+ unmaskRadioBtn.setToolTip(
+ 'Drawing unmasks with current level. Press <b>Ctrl</b> to mask')
+
+ self.maskStateGroup = qt.QButtonGroup()
+ self.maskStateGroup.addButton(maskRadioBtn, 1)
+ self.maskStateGroup.addButton(unmaskRadioBtn, 0)
+
+ self.maskStateWidget = self._hboxWidget(maskRadioBtn, unmaskRadioBtn)
+ layout.addWidget(self.maskStateWidget)
+
+ # Connect mask state widget visibility with browse action
+ self.maskStateWidget.setHidden(self.browseAction.isChecked())
+ self.browseAction.toggled[bool].connect(
+ self.maskStateWidget.setHidden)
+
+ # Pencil settings
+ self.pencilSetting = self._createPencilSettings(None)
+ self.pencilSetting.setVisible(False)
+ layout.addWidget(self.pencilSetting)
+
+ layout.addStretch(1)
+
+ drawGroup = qt.QGroupBox('Draw tools')
+ drawGroup.setLayout(layout)
+ return drawGroup
+
+ def _createPencilSettings(self, parent=None):
+ pencilSetting = qt.QWidget(parent)
+
+ self.pencilSpinBox = qt.QSpinBox(parent=pencilSetting)
+ self.pencilSpinBox.setRange(1, 1024)
+ pencilToolTip = """Set pencil drawing tool size in pixels of the image
+ on which to make the mask."""
+ self.pencilSpinBox.setToolTip(pencilToolTip)
+
+ self.pencilSlider = qt.QSlider(qt.Qt.Horizontal, parent=pencilSetting)
+ self.pencilSlider.setRange(1, 50)
+ self.pencilSlider.setToolTip(pencilToolTip)
+
+ pencilLabel = qt.QLabel('Pencil size:', parent=pencilSetting)
+
+ layout = qt.QGridLayout()
+ layout.addWidget(pencilLabel, 0, 0)
+ layout.addWidget(self.pencilSpinBox, 0, 1)
+ layout.addWidget(self.pencilSlider, 1, 1)
+ pencilSetting.setLayout(layout)
+
+ self.pencilSpinBox.valueChanged.connect(self._pencilWidthChanged)
+ self.pencilSlider.valueChanged.connect(self._pencilWidthChanged)
+
+ return pencilSetting
+
+ def _initThresholdGroupBox(self):
+ """Init thresholding widgets"""
+ layout = qt.QVBoxLayout()
+
+ # Thresholing
+
+ self.belowThresholdAction = qt.QAction(
+ icons.getQIcon('plot-roi-below'), 'Mask below threshold', None)
+ self.belowThresholdAction.setToolTip(
+ 'Mask image where values are below given threshold')
+ self.belowThresholdAction.setCheckable(True)
+ self.belowThresholdAction.triggered[bool].connect(
+ self._belowThresholdActionTriggered)
+
+ self.betweenThresholdAction = qt.QAction(
+ icons.getQIcon('plot-roi-between'), 'Mask within range', None)
+ self.betweenThresholdAction.setToolTip(
+ 'Mask image where values are within given range')
+ self.betweenThresholdAction.setCheckable(True)
+ self.betweenThresholdAction.triggered[bool].connect(
+ self._betweenThresholdActionTriggered)
+
+ self.aboveThresholdAction = qt.QAction(
+ icons.getQIcon('plot-roi-above'), 'Mask above threshold', None)
+ self.aboveThresholdAction.setToolTip(
+ 'Mask image where values are above given threshold')
+ self.aboveThresholdAction.setCheckable(True)
+ self.aboveThresholdAction.triggered[bool].connect(
+ self._aboveThresholdActionTriggered)
+
+ self.thresholdActionGroup = qt.QActionGroup(self)
+ self.thresholdActionGroup.setExclusive(False)
+ self.thresholdActionGroup.addAction(self.belowThresholdAction)
+ self.thresholdActionGroup.addAction(self.betweenThresholdAction)
+ self.thresholdActionGroup.addAction(self.aboveThresholdAction)
+ self.thresholdActionGroup.triggered.connect(
+ self._thresholdActionGroupTriggered)
+
+ self.loadColormapRangeAction = qt.QAction(
+ icons.getQIcon('view-refresh'), 'Set min-max from colormap', None)
+ self.loadColormapRangeAction.setToolTip(
+ 'Set min and max values from current colormap range')
+ self.loadColormapRangeAction.setCheckable(False)
+ self.loadColormapRangeAction.triggered.connect(
+ self._loadRangeFromColormapTriggered)
+
+ widgets = []
+ for action in self.thresholdActionGroup.actions():
+ btn = qt.QToolButton()
+ btn.setDefaultAction(action)
+ widgets.append(btn)
+
+ spacer = qt.QWidget()
+ spacer.setSizePolicy(qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Preferred)
+ widgets.append(spacer)
+
+ loadColormapRangeBtn = qt.QToolButton()
+ loadColormapRangeBtn.setDefaultAction(self.loadColormapRangeAction)
+ widgets.append(loadColormapRangeBtn)
+
+ container = self._hboxWidget(*widgets, stretch=False)
+ layout.addWidget(container)
+
+ form = qt.QFormLayout()
+
+ self.minLineEdit = qt.QLineEdit()
+ self.minLineEdit.setText('0')
+ self.minLineEdit.setValidator(qt.QDoubleValidator())
+ self.minLineEdit.setEnabled(False)
+ form.addRow('Min:', self.minLineEdit)
+
+ self.maxLineEdit = qt.QLineEdit()
+ self.maxLineEdit.setText('0')
+ self.maxLineEdit.setValidator(qt.QDoubleValidator())
+ self.maxLineEdit.setEnabled(False)
+ form.addRow('Max:', self.maxLineEdit)
+
+ self.applyMaskBtn = qt.QPushButton('Apply mask')
+ self.applyMaskBtn.clicked.connect(self._maskBtnClicked)
+ self.applyMaskBtn.setEnabled(False)
+ form.addRow(self.applyMaskBtn)
+
+ self.maskNanBtn = qt.QPushButton('Mask not finite values')
+ self.maskNanBtn.setToolTip('Mask Not a Number and infinite values')
+ self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked)
+ form.addRow(self.maskNanBtn)
+
+ thresholdWidget = qt.QWidget()
+ thresholdWidget.setLayout(form)
+ layout.addWidget(thresholdWidget)
+
+ layout.addStretch(1)
+
+ self.thresholdGroup = qt.QGroupBox('Threshold')
+ self.thresholdGroup.setLayout(layout)
+ return self.thresholdGroup
+
+ # track widget visibility and plot active image changes
+
+ def changeEvent(self, event):
+ """Reset drawing action when disabling widget"""
+ if (event.type() == qt.QEvent.EnabledChange and
+ not self.isEnabled() and
+ not self.browseAction.isChecked()):
+ self.browseAction.trigger() # Disable drawing tool
+
+ def save(self, filename, kind):
+ """Save current mask in a file
+
+ :param str filename: The file where to save to mask
+ :param str kind: The kind of file to save in 'edf', 'tif', 'npy'
+ :raise Exception: Raised if the process fails
+ """
+ self._mask.save(filename, kind)
+
+ def getCurrentMaskColor(self):
+ """Returns the color of the current selected level.
+
+ :rtype: A tuple or a python array
+ """
+ currentLevel = self.levelSpinBox.value()
+ if self._defaultColors[currentLevel]:
+ return self._defaultOverlayColor
+ else:
+ return self._overlayColors[currentLevel].tolist()
+
+ def _setMaskColors(self, level, alpha):
+ """Set-up the mask colormap to highlight current mask level.
+
+ :param int level: The mask level to highlight
+ :param float alpha: Alpha level of mask in [0., 1.]
+ """
+ assert 0 < level <= self._maxLevelNumber
+
+ colors = numpy.empty((self._maxLevelNumber + 1, 4), dtype=numpy.float32)
+
+ # Set color
+ colors[:, :3] = self._defaultOverlayColor[:3]
+
+ # check if some colors has been directly set by the user
+ mask = numpy.equal(self._defaultColors, False)
+ colors[mask, :3] = self._overlayColors[mask, :3]
+
+ # Set alpha
+ colors[:, -1] = alpha / 2.
+
+ # Set highlighted level color
+ colors[level, 3] = alpha
+
+ # Set no mask level
+ colors[0] = (0., 0., 0., 0.)
+
+ self._colormap['colors'] = colors
+
+ def resetMaskColors(self, level=None):
+ """Reset the mask color at the given level to be defaultColors
+
+ :param level:
+ The index of the mask for which we want to reset the color.
+ If none we will reset color for all masks.
+ """
+ if level is None:
+ self._defaultColors[level] = True
+ else:
+ self._defaultColors[:] = True
+
+ self._updateColors()
+
+ def setMaskColors(self, rgb, level=None):
+ """Set the masks color
+
+ :param rgb: The rgb color
+ :param level:
+ The index of the mask for which we want to change the color.
+ If none set this color for all the masks
+ """
+ if level is None:
+ self._overlayColors[:] = rgb
+ self._defaultColors[:] = False
+ else:
+ self._overlayColors[level] = rgb
+ self._defaultColors[level] = False
+
+ self._updateColors()
+
+ def getMaskColors(self):
+ """masks colors getter"""
+ return self._overlayColors
+
+ def _updateColors(self, *args):
+ """Rebuild mask colormap when selected level or transparency change"""
+ self._setMaskColors(self.levelSpinBox.value(),
+ self.transparencySlider.value() /
+ self.transparencySlider.maximum())
+ self._updatePlotMask()
+ self._updateInteractiveMode()
+
+ def _pencilWidthChanged(self, width):
+
+ old = self.pencilSpinBox.blockSignals(True)
+ try:
+ self.pencilSpinBox.setValue(width)
+ finally:
+ self.pencilSpinBox.blockSignals(old)
+
+ old = self.pencilSlider.blockSignals(True)
+ try:
+ self.pencilSlider.setValue(width)
+ finally:
+ self.pencilSlider.blockSignals(old)
+ self._updateInteractiveMode()
+
+ def _updateInteractiveMode(self):
+ """Update the current mode to the same if some cached data have to be
+ updated. It is the case for the color for example.
+ """
+ if self._drawingMode == 'rectangle':
+ self._activeRectMode()
+ elif self._drawingMode == 'polygon':
+ self._activePolygonMode()
+ elif self._drawingMode == 'pencil':
+ self._activePencilMode()
+
+ def _handleClearMask(self):
+ """Handle clear button clicked: reset current level mask"""
+ self._mask.clear(self.levelSpinBox.value())
+ self._mask.commit()
+
+ def _handleInvertMask(self):
+ """Invert the current mask level selection."""
+ self._mask.invert(self.levelSpinBox.value())
+ self._mask.commit()
+
+ # Handle drawing tools UI events
+
+ def _interactiveModeChanged(self, source):
+ """Handle plot interactive mode changed:
+
+ If changed from elsewhere, disable drawing tool
+ """
+ if source is not self:
+ # Do not trigger browseAction to avoid to call
+ # self.plot.setInteractiveMode
+ self.browseAction.setChecked(True)
+ self._releaseDrawingMode()
+
+ def _releaseDrawingMode(self):
+ """Release the drawing mode if is was used"""
+ if self._drawingMode is None:
+ return
+ self.plot.sigPlotSignal.disconnect(self._plotDrawEvent)
+ self._drawingMode = None
+
+ def _activeBrowseMode(self):
+ """Handle browse action mode triggered by user.
+
+ Set plot interactive mode only when
+ the user is triggering the browse action.
+ """
+ self._releaseDrawingMode()
+ self.plot.setInteractiveMode('zoom', source=self)
+ self._updateDrawingModeWidgets()
+
+ def _activeRectMode(self):
+ """Handle rect action mode triggering"""
+ self._releaseDrawingMode()
+ self._drawingMode = 'rectangle'
+ self.plot.sigPlotSignal.connect(self._plotDrawEvent)
+ color = self.getCurrentMaskColor()
+ self.plot.setInteractiveMode(
+ 'draw', shape='rectangle', source=self, color=color)
+ self._updateDrawingModeWidgets()
+
+ def _activePolygonMode(self):
+ """Handle polygon action mode triggering"""
+ self._releaseDrawingMode()
+ self._drawingMode = 'polygon'
+ self.plot.sigPlotSignal.connect(self._plotDrawEvent)
+ color = self.getCurrentMaskColor()
+ self.plot.setInteractiveMode('draw', shape='polygon', source=self, color=color)
+ self._updateDrawingModeWidgets()
+
+ def _activePencilMode(self):
+ """Handle pencil action mode triggering"""
+ self._releaseDrawingMode()
+ self._drawingMode = 'pencil'
+ self.plot.sigPlotSignal.connect(self._plotDrawEvent)
+ color = self.getCurrentMaskColor()
+ width = self.pencilSpinBox.value()
+ self.plot.setInteractiveMode(
+ 'draw', shape='pencil', source=self, color=color, width=width)
+ self._updateDrawingModeWidgets()
+
+ def _updateDrawingModeWidgets(self):
+ self.pencilSetting.setVisible(self._drawingMode == 'pencil')
+
+ # Handle plot drawing events
+
+ def _isMasking(self):
+ """Returns true if the tool is used for masking, else it is used for
+ unmasking.
+
+ :rtype: bool"""
+ # First draw event, use current modifiers for all draw sequence
+ doMask = (self.maskStateGroup.checkedId() == 1)
+ if qt.QApplication.keyboardModifiers() & qt.Qt.ControlModifier:
+ doMask = not doMask
+ return doMask
+
+ # Handle threshold UI events
+ def _belowThresholdActionTriggered(self, triggered):
+ if triggered:
+ self.minLineEdit.setEnabled(True)
+ self.maxLineEdit.setEnabled(False)
+ self.applyMaskBtn.setEnabled(True)
+
+ def _betweenThresholdActionTriggered(self, triggered):
+ if triggered:
+ self.minLineEdit.setEnabled(True)
+ self.maxLineEdit.setEnabled(True)
+ self.applyMaskBtn.setEnabled(True)
+
+ def _aboveThresholdActionTriggered(self, triggered):
+ if triggered:
+ self.minLineEdit.setEnabled(False)
+ self.maxLineEdit.setEnabled(True)
+ self.applyMaskBtn.setEnabled(True)
+
+ def _thresholdActionGroupTriggered(self, triggeredAction):
+ """Threshold action group listener."""
+ if triggeredAction.isChecked():
+ # Uncheck other actions
+ for action in self.thresholdActionGroup.actions():
+ if action is not triggeredAction and action.isChecked():
+ action.setChecked(False)
+ else:
+ # Disable min/max edit
+ self.minLineEdit.setEnabled(False)
+ self.maxLineEdit.setEnabled(False)
+ self.applyMaskBtn.setEnabled(False)
+
+ def _maskBtnClicked(self):
+ if self.belowThresholdAction.isChecked():
+ if self.minLineEdit.text():
+ self._mask.updateBelowThreshold(self.levelSpinBox.value(),
+ float(self.minLineEdit.text()))
+ self._mask.commit()
+
+ elif self.betweenThresholdAction.isChecked():
+ if self.minLineEdit.text() and self.maxLineEdit.text():
+ min_ = float(self.minLineEdit.text())
+ max_ = float(self.maxLineEdit.text())
+ self._mask.updateBetweenThresholds(self.levelSpinBox.value(),
+ min_, max_)
+ self._mask.commit()
+
+ elif self.aboveThresholdAction.isChecked():
+ if self.maxLineEdit.text():
+ max_ = float(self.maxLineEdit.text())
+ self._mask.updateAboveThreshold(self.levelSpinBox.value(),
+ max_)
+ self._mask.commit()
+
+ def _maskNotFiniteBtnClicked(self):
+ """Handle not finite mask button clicked: mask NaNs and inf"""
+ self._mask.updateNotFinite(
+ self.levelSpinBox.value())
+ self._mask.commit()
+
+
+class BaseMaskToolsDockWidget(qt.QDockWidget):
+ """Base class for :class:`MaskToolsWidget` and
+ :class:`ScatterMaskToolsWidget`
+
+ For integration in a :class:`PlotWindow`.
+
+ :param parent: See :class:`QDockWidget`
+ :paran str name: The title of this widget
+ """
+
+ sigMaskChanged = qt.Signal()
+
+ def __init__(self, parent=None, name='Mask'):
+ super(BaseMaskToolsDockWidget, self).__init__(parent)
+ self.setWindowTitle(name)
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.dockLocationChanged.connect(self._dockLocationChanged)
+ self.topLevelChanged.connect(self._topLevelChanged)
+
+ def _emitSigMaskChanged(self):
+ """Notify mask changes"""
+ # must be connected to self.widget().sigMaskChanged in child class
+ self.sigMaskChanged.emit()
+
+ def getSelectionMask(self, copy=True):
+ """Get the current mask as a 2D array.
+
+ :param bool copy: True (default) to get a copy of the mask.
+ If False, the returned array MUST not be modified.
+ :return: The array of the mask with dimension of the 'active' image.
+ If there is no active image, an empty array is returned.
+ :rtype: 2D numpy.ndarray of uint8
+ """
+ return self.widget().getSelectionMask(copy=copy)
+
+ def setSelectionMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask: The array to use for the mask.
+ :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ :return: None if failed, shape of mask as 2-tuple if successful.
+ The mask can be cropped or padded to fit active image,
+ the returned shape is that of the active image.
+ """
+ return self.widget().setSelectionMask(mask, copy=copy)
+
+ def toggleViewAction(self):
+ """Returns a checkable action that shows or closes this widget.
+
+ See :class:`QMainWindow`.
+ """
+ action = super(BaseMaskToolsDockWidget, self).toggleViewAction()
+ action.setIcon(icons.getQIcon('image-mask'))
+ action.setToolTip("Display/hide mask tools")
+ return action
+
+ def _dockLocationChanged(self, area):
+ if area in (qt.Qt.LeftDockWidgetArea, qt.Qt.RightDockWidgetArea):
+ direction = qt.QBoxLayout.TopToBottom
+ else:
+ direction = qt.QBoxLayout.LeftToRight
+ self.widget().setDirection(direction)
+
+ def _topLevelChanged(self, topLevel):
+ if topLevel:
+ self.widget().setDirection(qt.QBoxLayout.LeftToRight)
+ self.resize(self.widget().minimumSize())
+ self.adjustSize()
+
+ def showEvent(self, event):
+ """Make sure this widget is raised when it is shown
+ (when it is first created as a tab in PlotWindow or when it is shown
+ again after hiding).
+ """
+ self.raise_()
diff --git a/silx/gui/plot/__init__.py b/silx/gui/plot/__init__.py
new file mode 100644
index 0000000..06a24a7
--- /dev/null
+++ b/silx/gui/plot/__init__.py
@@ -0,0 +1,71 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of Qt widgets for plotting curves and images.
+
+The plotting API is inherited from the `PyMca <http://pymca.sourceforge.net/>`_
+plot API and is mostly compatible with it.
+
+Those widgets supports interaction (e.g., zoom, pan, selections).
+
+List of Qt widgets:
+
+.. currentmodule:: silx.gui.plot
+
+- :mod:`.PlotWidget`: A widget displaying a single plot.
+- :mod:`.PlotWindow`: A :mod:`.PlotWidget` with a configurable set of tools.
+- :class:`.Plot1D`: A widget with tools for curves.
+- :class:`.Plot2D`: A widget with tools for images.
+- :class:`.ImageView`: A widget with tools for images and a side histogram.
+- :class:`.StackView`: A widget with tools for a stack of images.
+
+By default, those widget are using matplotlib_.
+They can optionally use a faster OpenGL-based rendering (beta feature),
+which is enabled by setting the ``backend`` argument to ``'gl'``
+when creating the widgets (See :class:`.Plot`).
+
+.. note::
+
+ This package depends on matplotlib_.
+ The OpenGL backend further depends on
+ `PyOpenGL <http://pyopengl.sourceforge.net/>`_ and OpenGL >= 2.1.
+
+.. _matplotlib: http://matplotlib.org/
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "22/02/2016"
+
+
+# First of all init matplotlib and set its backend
+from .backends import _matplotlib # noqa
+
+from .PlotWidget import PlotWidget # noqa
+from .PlotWindow import PlotWindow, Plot1D, Plot2D # noqa
+from .ImageView import ImageView # noqa
+from .StackView import StackView # noqa
+
+__all__ = ['ImageView', 'PlotWidget', 'PlotWindow', 'Plot1D', 'Plot2D',
+ 'StackView']
diff --git a/silx/gui/plot/_utils/__init__.py b/silx/gui/plot/_utils/__init__.py
new file mode 100644
index 0000000..355bc02
--- /dev/null
+++ b/silx/gui/plot/_utils/__init__.py
@@ -0,0 +1,104 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Miscellaneous utility functions for the Plot"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/03/2017"
+
+
+import numpy
+
+from .panzoom import FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX
+from .panzoom import applyZoomToPlot, applyPan
+
+
+def clipColormapLogRange(colormap):
+ """Clip colormap vmin and vmax to 1, 10 if normalization is 'log' and vmin
+ or vmax <1
+
+ :param dict colormap: the colormap for which we want to clip vmin and vmax
+ """
+ if colormap['normalization'] is 'log':
+ if colormap['vmin'] < 1. or colormap['vmax'] < 1.:
+ colormap['vmin'], colormap['vmax'] = 1., 10.
+
+
+def addMarginsToLimits(margins, isXLog, isYLog,
+ xMin, xMax, yMin, yMax, y2Min=None, y2Max=None):
+ """Returns updated limits by extending them with margins.
+
+ :param margins: The ratio of the margins to add or None for no margins.
+ :type margins: A 4-tuple of floats as
+ (xMinMargin, xMaxMargin, yMinMargin, yMaxMargin)
+
+ :return: The updated limits
+ :rtype: tuple of 4 or 6 floats: Either (xMin, xMax, yMin, yMax) or
+ (xMin, xMax, yMin, yMax, y2Min, y2Max) if y2Min and y2Max
+ are provided.
+ """
+ if margins is not None:
+ xMinMargin, xMaxMargin, yMinMargin, yMaxMargin = margins
+
+ if not isXLog:
+ xRange = xMax - xMin
+ xMin -= xMinMargin * xRange
+ xMax += xMaxMargin * xRange
+
+ elif xMin > 0. and xMax > 0.: # Log scale
+ # Do not apply margins if limits < 0
+ xMinLog, xMaxLog = numpy.log10(xMin), numpy.log10(xMax)
+ xRangeLog = xMaxLog - xMinLog
+ xMin = pow(10., xMinLog - xMinMargin * xRangeLog)
+ xMax = pow(10., xMaxLog + xMaxMargin * xRangeLog)
+
+ if not isYLog:
+ yRange = yMax - yMin
+ yMin -= yMinMargin * yRange
+ yMax += yMaxMargin * yRange
+ elif yMin > 0. and yMax > 0.: # Log scale
+ # Do not apply margins if limits < 0
+ yMinLog, yMaxLog = numpy.log10(yMin), numpy.log10(yMax)
+ yRangeLog = yMaxLog - yMinLog
+ yMin = pow(10., yMinLog - yMinMargin * yRangeLog)
+ yMax = pow(10., yMaxLog + yMaxMargin * yRangeLog)
+
+ if y2Min is not None and y2Max is not None:
+ if not isYLog:
+ yRange = y2Max - y2Min
+ y2Min -= yMinMargin * yRange
+ y2Max += yMaxMargin * yRange
+ elif y2Min > 0. and y2Max > 0.: # Log scale
+ # Do not apply margins if limits < 0
+ yMinLog, yMaxLog = numpy.log10(y2Min), numpy.log10(y2Max)
+ yRangeLog = yMaxLog - yMinLog
+ y2Min = pow(10., yMinLog - yMinMargin * yRangeLog)
+ y2Max = pow(10., yMaxLog + yMaxMargin * yRangeLog)
+
+ if y2Min is None or y2Max is None:
+ return xMin, xMax, yMin, yMax
+ else:
+ return xMin, xMax, yMin, yMax, y2Min, y2Max
+
diff --git a/silx/gui/plot/_utils/panzoom.py b/silx/gui/plot/_utils/panzoom.py
new file mode 100644
index 0000000..bec31df
--- /dev/null
+++ b/silx/gui/plot/_utils/panzoom.py
@@ -0,0 +1,156 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Functions to apply pan and zoom on a Plot"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/03/2017"
+
+
+import math
+
+import numpy
+
+
+# Float 32 info ###############################################################
+# Using min/max value below limits of float32
+# so operation with such value (e.g., max - min) do not overflow
+
+FLOAT32_SAFE_MIN = -1e37
+FLOAT32_MINPOS = numpy.finfo(numpy.float32).tiny
+FLOAT32_SAFE_MAX = 1e37
+# TODO double support
+
+
+def scale1DRange(min_, max_, center, scale, isLog):
+ """Scale a 1D range given a scale factor and an center point.
+
+ Keeps the values in a smaller range than float32.
+
+ :param float min_: The current min value of the range.
+ :param float max_: The current max value of the range.
+ :param float center: The center of the zoom (i.e., invariant point).
+ :param float scale: The scale to use for zoom
+ :param bool isLog: Whether using log scale or not.
+ :return: The zoomed range.
+ :rtype: tuple of 2 floats: (min, max)
+ """
+ if isLog:
+ # Min and center can be < 0 when
+ # autoscale is off and switch to log scale
+ # max_ < 0 should not happen
+ min_ = numpy.log10(min_) if min_ > 0. else FLOAT32_MINPOS
+ center = numpy.log10(center) if center > 0. else FLOAT32_MINPOS
+ max_ = numpy.log10(max_) if max_ > 0. else FLOAT32_MINPOS
+
+ if min_ == max_:
+ return min_, max_
+
+ offset = (center - min_) / (max_ - min_)
+ range_ = (max_ - min_) / scale
+ newMin = center - offset * range_
+ newMax = center + (1. - offset) * range_
+
+ if isLog:
+ # No overflow as exponent is log10 of a float32
+ newMin = pow(10., newMin)
+ newMax = pow(10., newMax)
+ newMin = numpy.clip(newMin, FLOAT32_MINPOS, FLOAT32_SAFE_MAX)
+ newMax = numpy.clip(newMax, FLOAT32_MINPOS, FLOAT32_SAFE_MAX)
+ else:
+ newMin = numpy.clip(newMin, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX)
+ newMax = numpy.clip(newMax, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX)
+ return newMin, newMax
+
+
+def applyZoomToPlot(plot, scaleF, center=None):
+ """Zoom in/out plot given a scale and a center point.
+
+ :param plot: The plot on which to apply zoom.
+ :param float scaleF: Scale factor of zoom.
+ :param center: (x, y) coords in pixel coordinates of the zoom center.
+ :type center: 2-tuple of float
+ """
+ xMin, xMax = plot.getGraphXLimits()
+ yMin, yMax = plot.getGraphYLimits()
+
+ if center is None:
+ left, top, width, height = plot.getPlotBoundsInPixels()
+ cx, cy = left + width // 2, top + height // 2
+ else:
+ cx, cy = center
+
+ dataCenterPos = plot.pixelToData(cx, cy)
+ assert dataCenterPos is not None
+
+ xMin, xMax = scale1DRange(xMin, xMax, dataCenterPos[0], scaleF,
+ plot.isXAxisLogarithmic())
+
+ yMin, yMax = scale1DRange(yMin, yMax, dataCenterPos[1], scaleF,
+ plot.isYAxisLogarithmic())
+
+ dataPos = plot.pixelToData(cx, cy, axis="right")
+ assert dataPos is not None
+ y2Center = dataPos[1]
+ y2Min, y2Max = plot.getGraphYLimits(axis="right")
+ y2Min, y2Max = scale1DRange(y2Min, y2Max, y2Center, scaleF,
+ plot.isYAxisLogarithmic())
+
+ plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+
+
+def applyPan(min_, max_, panFactor, isLog10):
+ """Returns a new range with applied panning.
+
+ Moves the range according to panFactor.
+ If isLog10 is True, converts to log10 before moving.
+
+ :param float min_: Min value of the data range to pan.
+ :param float max_: Max value of the data range to pan.
+ Must be >= min.
+ :param float panFactor: Signed proportion of the range to use for pan.
+ :param bool isLog10: True if log10 scale, False if linear scale.
+ :return: New min and max value with pan applied.
+ :rtype: 2-tuple of float.
+ """
+ if isLog10 and min_ > 0.:
+ # Negative range and log scale can happen with matplotlib
+ logMin, logMax = math.log10(min_), math.log10(max_)
+ logOffset = panFactor * (logMax - logMin)
+ newMin = pow(10., logMin + logOffset)
+ newMax = pow(10., logMax + logOffset)
+
+ # Takes care of out-of-range values
+ if newMin > 0. and newMax < float('inf'):
+ min_, max_ = newMin, newMax
+
+ else:
+ offset = panFactor * (max_ - min_)
+ newMin, newMax = min_ + offset, max_ + offset
+
+ # Takes care of out-of-range values
+ if newMin > - float('inf') and newMax < float('inf'):
+ min_, max_ = newMin, newMax
+ return min_, max_
diff --git a/silx/gui/plot/_utils/setup.py b/silx/gui/plot/_utils/setup.py
new file mode 100644
index 0000000..0271745
--- /dev/null
+++ b/silx/gui/plot/_utils/setup.py
@@ -0,0 +1,42 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/03/2017"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('_utils', parent_package, top_path)
+ config.add_subpackage('test')
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/silx/gui/plot/_utils/test/__init__.py b/silx/gui/plot/_utils/test/__init__.py
new file mode 100644
index 0000000..4a443ac
--- /dev/null
+++ b/silx/gui/plot/_utils/test/__init__.py
@@ -0,0 +1,41 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/10/2016"
+
+
+import unittest
+
+from .test_ticklayout import suite as test_ticklayout_suite
+
+
+def suite():
+ testsuite = unittest.TestSuite()
+ testsuite.addTest(test_ticklayout_suite())
+ return testsuite
diff --git a/silx/gui/plot/_utils/test/test_ticklayout.py b/silx/gui/plot/_utils/test/test_ticklayout.py
new file mode 100644
index 0000000..8c67620
--- /dev/null
+++ b/silx/gui/plot/_utils/test/test_ticklayout.py
@@ -0,0 +1,78 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/10/2016"
+
+
+import unittest
+
+from silx.test.utils import ParametricTestCase
+
+from silx.gui.plot._utils import ticklayout
+
+
+class TestTickLayout(ParametricTestCase):
+ """Test ticks layout algorithms"""
+
+ def testNiceNumbers(self):
+ """Minimalistic tests of :func:`niceNumbers`"""
+ tests = { # (vmin, vmax): ref_ticks
+ (0.5, 10.5): (0.0, 12.0, 2.0, 0),
+ (10000., 10000.5): (10000.0, 10000.5, 0.1, 1),
+ (0.001, 0.005): (0.001, 0.005, 0.001, 3)
+ }
+
+ for (vmin, vmax), ref_ticks in tests.items():
+ with self.subTest(vmin=vmin, vmax=vmax):
+ ticks = ticklayout.niceNumbers(vmin, vmax)
+ self.assertEqual(ticks, ref_ticks)
+
+ def testNiceNumbersLog(self):
+ """Minimalistic tests of :func:`niceNumbersForLog10`"""
+ tests = { # (log10(min), log10(max): ref_ticks
+ (0., 3.): (0, 3, 1, 0),
+ (-3., 3): (-3, 3, 1, 0),
+ (-32., 0.): (-36, 0, 6, 0)
+ }
+
+ for (vmin, vmax), ref_ticks in tests.items():
+ with self.subTest(vmin=vmin, vmax=vmax):
+ ticks = ticklayout.niceNumbersForLog10(vmin, vmax)
+ self.assertEqual(ticks, ref_ticks)
+
+
+def suite():
+ testsuite = unittest.TestSuite()
+ testsuite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestTickLayout))
+ return testsuite
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/silx/gui/plot/_utils/ticklayout.py b/silx/gui/plot/_utils/ticklayout.py
new file mode 100644
index 0000000..5f4b636
--- /dev/null
+++ b/silx/gui/plot/_utils/ticklayout.py
@@ -0,0 +1,224 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module implements labels layout on graph axes."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/10/2016"
+
+
+import math
+
+
+# utils #######################################################################
+
+def numberOfDigits(tickSpacing):
+ """Returns the number of digits to display for text label.
+
+ :param float tickSpacing: Step between ticks in data space.
+ :return: Number of digits to show for labels.
+ :rtype: int
+ """
+ nfrac = int(-math.floor(math.log10(tickSpacing)))
+ if nfrac < 0:
+ nfrac = 0
+ return nfrac
+
+
+# Nice Numbers ################################################################
+
+def _niceNum(value, isRound=False):
+ expvalue = math.floor(math.log10(value))
+ frac = value/pow(10., expvalue)
+ if isRound:
+ if frac < 1.5:
+ nicefrac = 1.
+ elif frac < 3.:
+ nicefrac = 2.
+ elif frac < 7.:
+ nicefrac = 5.
+ else:
+ nicefrac = 10.
+ else:
+ if frac <= 1.:
+ nicefrac = 1.
+ elif frac <= 2.:
+ nicefrac = 2.
+ elif frac <= 5.:
+ nicefrac = 5.
+ else:
+ nicefrac = 10.
+ return nicefrac * pow(10., expvalue)
+
+
+def niceNumbers(vMin, vMax, nTicks=5):
+ """Returns tick positions.
+
+ This function implements graph labels layout using nice numbers
+ by Paul Heckbert from "Graphics Gems", Academic Press, 1990.
+ See `C code <http://tog.acm.org/resources/GraphicsGems/gems/Label.c>`_.
+
+ :param float vMin: The min value on the axis
+ :param float vMax: The max value on the axis
+ :param int nTicks: The number of ticks to position
+ :returns: min, max, increment value of tick positions and
+ number of fractional digit to show
+ :rtype: tuple
+ """
+ vrange = _niceNum(vMax - vMin, False)
+ spacing = _niceNum(vrange / nTicks, True)
+ graphmin = math.floor(vMin / spacing) * spacing
+ graphmax = math.ceil(vMax / spacing) * spacing
+ nfrac = numberOfDigits(spacing)
+ return graphmin, graphmax, spacing, nfrac
+
+
+def _frange(start, stop, step):
+ """range for float (including stop)."""
+ assert step >= 0.
+ while start <= stop:
+ yield start
+ start += step
+
+
+def ticks(vMin, vMax, nbTicks=5):
+ """Returns tick positions and labels using nice numbers algorithm.
+
+ This enforces ticks to be within [vMin, vMax] range.
+ It returns at least 2 ticks.
+
+ :param float vMin: The min value on the axis
+ :param float vMax: The max value on the axis
+ :param int nbTicks: The number of ticks to position
+ :returns: tick positions and corresponding text labels
+ :rtype: 2-tuple: list of float, list of string
+ """
+ start, end, step, nfrac = niceNumbers(vMin, vMax, nbTicks)
+ positions = [t for t in _frange(start, end, step) if vMin <= t <= vMax]
+
+ # Makes sure there is at least 2 ticks
+ if len(positions) < 2:
+ positions = [vMin, vMax]
+ nfrac = numberOfDigits(vMax - vMin)
+
+ # Generate labels
+ format_ = '%g' if nfrac == 0 else '%.{}f'.format(nfrac)
+ labels = [format_ % tick for tick in positions]
+ return positions, labels
+
+
+def niceNumbersAdaptative(vMin, vMax, axisLength, tickDensity):
+ """Returns tick positions using :func:`niceNumbers` and a
+ density of ticks.
+
+ axisLength and tickDensity are based on the same unit (e.g., pixel).
+
+ :param float vMin: The min value on the axis
+ :param float vMax: The max value on the axis
+ :param float axisLength: The length of the axis.
+ :param float tickDensity: The density of ticks along the axis.
+ :returns: min, max, increment value of tick positions and
+ number of fractional digit to show
+ :rtype: tuple
+ """
+ # At least 2 ticks
+ nticks = max(2, int(round(tickDensity * axisLength)))
+ tickmin, tickmax, step, nfrac = niceNumbers(vMin, vMax, nticks)
+
+ return tickmin, tickmax, step, nfrac
+
+
+# Nice Numbers for log scale ##################################################
+
+def niceNumbersForLog10(minLog, maxLog, nTicks=5):
+ """Return tick positions for logarithmic scale
+
+ :param float minLog: log10 of the min value on the axis
+ :param float maxLog: log10 of the max value on the axis
+ :param int nTicks: The number of ticks to position
+ :returns: log10 of min, max, increment value of tick positions and
+ number of fractional digit to show
+ :rtype: tuple of int
+ """
+ graphminlog = math.floor(minLog)
+ graphmaxlog = math.ceil(maxLog)
+ rangelog = graphmaxlog - graphminlog
+
+ if rangelog <= nTicks:
+ spacing = 1.
+ else:
+ spacing = math.floor(rangelog / nTicks)
+
+ graphminlog = math.floor(graphminlog / spacing) * spacing
+ graphmaxlog = math.ceil(graphmaxlog / spacing) * spacing
+
+ nfrac = numberOfDigits(spacing)
+
+ return int(graphminlog), int(graphmaxlog), int(spacing), nfrac
+
+
+def niceNumbersAdaptativeForLog10(vMin, vMax, axisLength, tickDensity):
+ """Returns tick positions using :func:`niceNumbers` and a
+ density of ticks.
+
+ axisLength and tickDensity are based on the same unit (e.g., pixel).
+
+ :param float vMin: The min value on the axis
+ :param float vMax: The max value on the axis
+ :param float axisLength: The length of the axis.
+ :param float tickDensity: The density of ticks along the axis.
+ :returns: log10 of min, max, increment value of tick positions and
+ number of fractional digit to show
+ :rtype: tuple
+ """
+ # At least 2 ticks
+ nticks = max(2, int(round(tickDensity * axisLength)))
+ tickmin, tickmax, step, nfrac = niceNumbersForLog10(vMin, vMax, nticks)
+
+ return tickmin, tickmax, step, nfrac
+
+
+def computeLogSubTicks(ticks, lowBound, highBound):
+ """Return the sub ticks for the log scale for all given ticks if subtick
+ is in [lowBound, highBound]
+
+ :param ticks: log10 of the ticks
+ :param lowBound: the lower boundary of ticks
+ :param highBound: the higher boundary of ticks
+ :return: all the sub ticks contained in ticks (log10)
+ """
+ if len(ticks) < 1:
+ return []
+
+ res = []
+ for logPos in ticks:
+ dataOrigPos = logPos
+ for index in range(2, 10):
+ dataPos = dataOrigPos * index
+ if lowBound <= dataPos <= highBound:
+ res.append(dataPos)
+ return res
diff --git a/silx/gui/plot/backends/BackendBase.py b/silx/gui/plot/backends/BackendBase.py
new file mode 100644
index 0000000..74f96af
--- /dev/null
+++ b/silx/gui/plot/backends/BackendBase.py
@@ -0,0 +1,474 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Base class for Plot backends.
+
+It documents the Plot backend API.
+
+This API is a simplified version of PyMca PlotBackend API.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/02/2016"
+
+
+import weakref
+
+
+# Names for setCursor
+CURSOR_DEFAULT = 'default'
+CURSOR_POINTING = 'pointing'
+CURSOR_SIZE_HOR = 'size horizontal'
+CURSOR_SIZE_VER = 'size vertical'
+CURSOR_SIZE_ALL = 'size all'
+
+
+class BackendBase(object):
+ """Class defining the API a backend of the Plot should provide."""
+
+ def __init__(self, plot, parent=None):
+ """Init.
+
+ :param Plot plot: The Plot this backend is attached to
+ :param parent: The parent widget of the plot widget.
+ """
+ self.__xLimits = 1., 100.
+ self.__yLimits = {'left': (1., 100.), 'right': (1., 100.)}
+ self.__yAxisInverted = False
+ self.__keepDataAspectRatio = False
+ # Store a weakref to get access to the plot state.
+ self._setPlot(plot)
+
+ @property
+ def _plot(self):
+ """The plot this backend is attached to."""
+ if self._plotRef is None:
+ raise RuntimeError('This backend is not attached to a Plot')
+
+ plot = self._plotRef()
+ if plot is None:
+ raise RuntimeError('This backend is no more attached to a Plot')
+ return plot
+
+ def _setPlot(self, plot):
+ """Allow to set plot after init.
+
+ Use with caution, basically **immediately** after init.
+ """
+ self._plotRef = weakref.ref(plot)
+
+ # Add methods
+
+ def addCurve(self, x, y, legend,
+ color, symbol, linewidth, linestyle,
+ yaxis,
+ xerror, yerror, z, selectable,
+ fill, alpha, symbolsize):
+ """Add a 1D curve given by x an y to the graph.
+
+ :param numpy.ndarray x: The data corresponding to the x axis
+ :param numpy.ndarray y: The data corresponding to the y axis
+ :param str legend: The legend to be associated to the curve
+ :param color: color(s) to be used
+ :type color: string ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in Colors.py
+ :param str symbol: Symbol to be drawn at each (x, y) position::
+
+ - ' ' or '' no symbol
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :param float linewidth: The width of the curve in pixels
+ :param str linestyle: Type of line::
+
+ - ' ' or '' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+
+ :param str yaxis: The Y axis this curve belongs to in: 'left', 'right'
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: numpy.ndarray or None
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: numpy.ndarray or None
+ :param int z: Layer on which to draw the cuve
+ :param bool selectable: indicate if the curve can be selected
+ :param bool fill: True to fill the curve, False otherwise
+ :param float alpha: Curve opacity, as a float in [0., 1.]
+ :param float symbolsize: Size of the symbol (if any) drawn
+ at each (x, y) position.
+ :returns: The handle used by the backend to univocally access the curve
+ """
+ return legend
+
+ def addImage(self, data, legend,
+ origin, scale, z,
+ selectable, draggable,
+ colormap, alpha):
+ """Add an image to the plot.
+
+ :param numpy.ndarray data: (nrows, ncolumns) data or
+ (nrows, ncolumns, RGBA) ubyte array
+ :param str legend: The legend to be associated to the image
+ :param origin: (origin X, origin Y) of the data.
+ Default: (0., 0.)
+ :type origin: 2-tuple of float
+ :param scale: (scale X, scale Y) of the data.
+ Default: (1., 1.)
+ :type scale: 2-tuple of float
+ :param int z: Layer on which to draw the image
+ :param bool selectable: indicate if the image can be selected
+ :param bool draggable: indicate if the image can be moved
+ :param colormap: Dictionary describing the colormap to use.
+ Ignored if data is RGB(A).
+ :type colormap: dict or None
+ :param float alpha: Opacity of the image, as a float in range [0, 1].
+ :returns: The handle used by the backend to univocally access the image
+ """
+ return legend
+
+ def addItem(self, x, y, legend, shape, color, fill, overlay, z):
+ """Add an item (i.e. a shape) to the plot.
+
+ :param numpy.ndarray x: The X coords of the points of the shape
+ :param numpy.ndarray y: The Y coords of the points of the shape
+ :param str legend: The legend to be associated to the item
+ :param str shape: Type of item to be drawn in
+ hline, polygon, rectangle, vline, polylines
+ :param str color: Color of the item
+ :param bool fill: True to fill the shape
+ :param bool overlay: True if item is an overlay, False otherwise
+ :param int z: Layer on which to draw the item
+ :returns: The handle used by the backend to univocally access the item
+ """
+ return legend
+
+ def addMarker(self, x, y, legend, text, color,
+ selectable, draggable,
+ symbol, constraint, overlay):
+ """Add a point, vertical line or horizontal line marker to the plot.
+
+ :param float x: Horizontal position of the marker in graph coordinates.
+ If None, the marker is a horizontal line.
+ :param float y: Vertical position of the marker in graph coordinates.
+ If None, the marker is a vertical line.
+ :param str legend: Legend associated to the marker
+ :param str text: Text associated to the marker (or None for no text)
+ :param str color: Color to be used for instance 'blue', 'b', '#FF0000'
+ :param bool selectable: indicate if the marker can be selected
+ :param bool draggable: indicate if the marker can be moved
+ :param str symbol: Symbol representing the marker.
+ Only relevant for point markers where X and Y are not None.
+ Value in:
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ This parameter is only used if draggable is True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :param bool overlay: True if marker is an overlay (Default: False).
+ This allows for rendering optimization if this
+ marker is changed often.
+ :return: Handle used by the backend to univocally access the marker
+ """
+ return legend
+
+ # Remove methods
+
+ def remove(self, item):
+ """Remove an existing item from the plot.
+
+ :param item: A backend specific item handle returned by a add* method
+ """
+ pass
+
+ # Interaction methods
+
+ def setGraphCursorShape(self, cursor):
+ """Set the cursor shape.
+
+ To override in interactive backends.
+
+ :param str cursor: Name of the cursor shape or None
+ """
+ pass
+
+ def setGraphCursor(self, flag, color, linewidth, linestyle):
+ """Toggle the display of a crosshair cursor and set its attributes.
+
+ To override in interactive backends.
+
+ :param bool flag: Toggle the display of a crosshair cursor.
+ :param color: The color to use for the crosshair.
+ :type color: A string (either a predefined color name in Colors.py
+ or "#RRGGBB")) or a 4 columns unsigned byte array.
+ :param int linewidth: The width of the lines of the crosshair.
+ :param linestyle: Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+
+ :type linestyle: None or one of the predefined styles.
+ """
+ pass
+
+ def pickItems(self, x, y):
+ """Get a list of items at a pixel position.
+
+ :param float x: The x pixel coord where to pick.
+ :param float y: The y pixel coord where to pick.
+ :return: All picked items from back to front.
+ One dict per item,
+ with 'kind' key in 'curve', 'marker', 'image';
+ 'legend' key, the item legend.
+ and for curves, 'xdata' and 'ydata' keys storing picked
+ position on the curve.
+ :rtype: list of dict
+ """
+ return []
+
+ # Update curve
+
+ def setCurveColor(self, curve, color):
+ """Set the color of a curve.
+
+ :param curve: The curve handle
+ :param str color: The color to use.
+ """
+ pass
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ """Return the widget this backend is drawing to."""
+ return None
+
+ def postRedisplay(self):
+ """Trigger a :meth:`Plot.replot`.
+
+ Default implementation triggers a synchronous replot if plot is dirty.
+ This method should be overridden by the embedding widget in order to
+ provide an asynchronous call to replot in order to optimize the number
+ replot operations.
+ """
+ # This method can be deferred and it might happen that plot has been
+ # destroyed in between, especially with unittests
+
+ plot = self._plotRef()
+ if plot is not None and plot._getDirtyPlot():
+ plot.replot()
+
+ def replot(self):
+ """Redraw the plot."""
+ pass
+
+ def saveGraph(self, fileName, fileFormat, dpi):
+ """Save the graph to a file (or a StringIO)
+
+ :param fileName: Destination
+ :type fileName: String or StringIO or BytesIO
+ :param str fileFormat: String specifying the format
+ :param int dpi: The resolution to use or None.
+ """
+ pass
+
+ # Graph labels
+
+ def setGraphTitle(self, title):
+ """Set the main title of the plot.
+
+ :param str title: Title associated to the plot
+ """
+ pass
+
+ def setGraphXLabel(self, label):
+ """Set the X axis label.
+
+ :param str label: label associated to the plot bottom X axis
+ """
+ pass
+
+ def setGraphYLabel(self, label, axis):
+ """Set the left Y axis label.
+
+ :param str label: label associated to the plot left Y axis
+ :param str axis: The axis for which to get the limits: left or right
+ """
+ pass
+
+ # Graph limits
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ """Set the limits of the X and Y axes at once.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ :param float ymin: minimum left axis value
+ :param float ymax: maximum left axis value
+ :param float y2min: minimum right axis value
+ :param float y2max: maximum right axis value
+ """
+ self.__xLimits = xmin, xmax
+ self.__yLimits['left'] = ymin, ymax
+ if y2min is not None and y2max is not None:
+ self.__yLimits['right'] = y2min, y2max
+
+ def getGraphXLimits(self):
+ """Get the graph X (bottom) limits.
+
+ :return: Minimum and maximum values of the X axis
+ """
+ return self.__xLimits
+
+ def setGraphXLimits(self, xmin, xmax):
+ """Set the limits of X axis.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ """
+ self.__xLimits = xmin, xmax
+
+ def getGraphYLimits(self, axis):
+ """Get the graph Y (left) limits.
+
+ :param str axis: The axis for which to get the limits: left or right
+ :return: Minimum and maximum values of the Y axis
+ """
+ return self.__yLimits[axis]
+
+ def setGraphYLimits(self, ymin, ymax, axis):
+ """Set the limits of the Y axis.
+
+ :param float ymin: minimum left axis value
+ :param float ymax: maximum left axis value
+ :param str axis: The axis for which to get the limits: left or right
+ """
+ self.__yLimits[axis] = ymin, ymax
+
+ # Graph axes
+
+ def setXAxisLogarithmic(self, flag):
+ """Set the X axis scale between linear and log.
+
+ :param bool flag: If True, the bottom axis will use a log scale
+ """
+ pass
+
+ def setYAxisLogarithmic(self, flag):
+ """Set the Y axis scale between linear and log.
+
+ :param bool flag: If True, the left axis will use a log scale
+ """
+ pass
+
+ def setYAxisInverted(self, flag):
+ """Invert the Y axis.
+
+ :param bool flag: If True, put the vertical axis origin on the top
+ """
+ self.__yAxisInverted = bool(flag)
+
+ def isYAxisInverted(self):
+ """Return True if left Y axis is inverted, False otherwise."""
+ return self.__yAxisInverted
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not."""
+ return self.__keepDataAspectRatio
+
+ def setKeepDataAspectRatio(self, flag):
+ """Set whether to keep data aspect ratio or not.
+
+ :param flag: True to respect data aspect ratio
+ :type flag: Boolean, default True
+ """
+ self.__keepDataAspectRatio = bool(flag)
+
+ def setGraphGrid(self, which):
+ """Set grid.
+
+ :param which: None to disable grid, 'major' for major grid,
+ 'both' for major and minor grid
+ """
+ pass
+
+ # Data <-> Pixel coordinates conversion
+
+ def dataToPixel(self, x, y, axis):
+ """Convert a position in data space to a position in pixels
+ in the widget.
+
+ :param float x: The X coordinate in data space.
+ :param float y: The Y coordinate in data space.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :returns: The corresponding position in pixels or
+ None if the data position is not in the displayed area.
+ :rtype: A tuple of 2 floats: (xPixel, yPixel) or None.
+ """
+ raise NotImplementedError()
+
+ def pixelToData(self, x, y, axis, check):
+ """Convert a position in pixels in the widget to a position in
+ the data space.
+
+ :param float x: The X coordinate in pixels.
+ :param float y: The Y coordinate in pixels.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :param bool check: True to check if the coordinates are in the
+ plot area.
+ :returns: The corresponding position in data space or
+ None if the pixel position is not in the plot area.
+ :rtype: A tuple of 2 floats: (xData, yData) or None.
+ """
+ raise NotImplementedError()
+
+ def getPlotBoundsInPixels(self):
+ """Plot area bounds in widget coordinates in pixels.
+
+ :return: bounds as a 4-tuple of int: (left, top, width, height)
+ """
+ raise NotImplementedError()
diff --git a/silx/gui/plot/backends/BackendMatplotlib.py b/silx/gui/plot/backends/BackendMatplotlib.py
new file mode 100644
index 0000000..f9e60d5
--- /dev/null
+++ b/silx/gui/plot/backends/BackendMatplotlib.py
@@ -0,0 +1,821 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Matplotlib Plot backend."""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent, H. Payno"]
+__license__ = "MIT"
+__date__ = "18/01/2017"
+
+
+import logging
+
+import numpy
+
+
+_logger = logging.getLogger(__name__)
+
+
+from ... import qt
+
+from ._matplotlib import FigureCanvasQTAgg
+import matplotlib
+from matplotlib.container import Container
+from matplotlib.figure import Figure
+from matplotlib.patches import Rectangle, Polygon
+from matplotlib.image import AxesImage
+from matplotlib.backend_bases import MouseEvent
+from matplotlib.lines import Line2D
+from matplotlib.collections import PathCollection, LineCollection
+
+from .ModestImage import ModestImage
+from . import BackendBase
+from .. import Colors
+from .._utils import FLOAT32_MINPOS
+
+
+class BackendMatplotlib(BackendBase.BackendBase):
+ """Base class for Matplotlib backend without a FigureCanvas.
+
+ For interactive on screen plot, see :class:`BackendMatplotlibQt`.
+
+ See :class:`BackendBase.BackendBase` for public API documentation.
+ """
+
+ def __init__(self, plot, parent=None):
+ super(BackendMatplotlib, self).__init__(plot, parent)
+
+ # matplotlib is handling keep aspect ratio at draw time
+ # When keep aspect ratio is on, and one changes the limits and
+ # ask them *before* next draw has been performed he will get the
+ # limits without applying keep aspect ratio.
+ # This attribute is used to ensure consistent values returned
+ # when getting the limits at the expense of a replot
+ self._dirtyLimits = True
+
+ self.fig = Figure()
+ self.fig.set_facecolor("w")
+
+ self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
+ self.ax2 = self.ax.twinx()
+ self.ax2.set_label("right")
+
+ # critical for picking!!!!
+ self.ax2.set_zorder(0)
+ self.ax2.set_autoscaley_on(True)
+ self.ax.set_zorder(1)
+ # this works but the figure color is left
+ if matplotlib.__version__[0] < '2':
+ self.ax.set_axis_bgcolor('none')
+ else:
+ self.ax.set_facecolor('none')
+ self.fig.sca(self.ax)
+
+ self._overlays = set()
+ self._background = None
+
+ self._colormaps = {}
+
+ self._graphCursor = tuple()
+ self.matplotlibVersion = matplotlib.__version__
+
+ self.setGraphXLimits(0., 100.)
+ self.setGraphYLimits(0., 100., axis='right')
+ self.setGraphYLimits(0., 100., axis='left')
+
+ self._enableAxis('right', False)
+
+ # Add methods
+
+ def addCurve(self, x, y, legend,
+ color, symbol, linewidth, linestyle,
+ yaxis,
+ xerror, yerror, z, selectable,
+ fill, alpha, symbolsize):
+ for parameter in (x, y, legend, color, symbol, linewidth, linestyle,
+ yaxis, z, selectable, fill, alpha, symbolsize):
+ assert parameter is not None
+ assert yaxis in ('left', 'right')
+
+ if (len(color) == 4 and
+ type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
+ color = numpy.array(color, dtype=numpy.float) / 255.
+
+ if yaxis == "right":
+ axes = self.ax2
+ self._enableAxis("right", True)
+ else:
+ axes = self.ax
+
+ picker = 3 if selectable else None
+
+ artists = [] # All the artists composing the curve
+
+ # First add errorbars if any so they are behind the curve
+ if xerror is not None or yerror is not None:
+ if hasattr(color, 'dtype') and len(color) == len(x):
+ errorbarColor = 'k'
+ else:
+ errorbarColor = color
+
+ # On Debian 7 at least, Nx1 array yerr does not seems supported
+ if (yerror is not None and yerror.ndim == 2 and
+ yerror.shape[1] == 1 and len(x) != 1):
+ yerror = numpy.ravel(yerror)
+
+ errorbars = axes.errorbar(x, y, label=legend,
+ xerr=xerror, yerr=yerror,
+ linestyle=' ', color=errorbarColor)
+ artists += list(errorbars.get_children())
+
+ if hasattr(color, 'dtype') and len(color) == len(x):
+ # scatter plot
+ if color.dtype not in [numpy.float32, numpy.float]:
+ actualColor = color / 255.
+ else:
+ actualColor = color
+
+ if linestyle not in ["", " ", None]:
+ # scatter plot with an actual line ...
+ # we need to assign a color ...
+ curveList = axes.plot(x, y, label=legend,
+ linestyle=linestyle,
+ color=actualColor[0],
+ linewidth=linewidth,
+ picker=picker,
+ marker=None)
+ artists += list(curveList)
+
+ scatter = axes.scatter(x, y,
+ label=legend,
+ color=actualColor,
+ marker=symbol,
+ picker=picker,
+ s=symbolsize)
+ artists.append(scatter)
+
+ if fill:
+ artists.append(axes.fill_between(
+ x, FLOAT32_MINPOS, y, facecolor=actualColor[0], linestyle=''))
+
+ else: # Curve
+ curveList = axes.plot(x, y,
+ label=legend,
+ linestyle=linestyle,
+ color=color,
+ linewidth=linewidth,
+ marker=symbol,
+ picker=picker,
+ markersize=symbolsize)
+ artists += list(curveList)
+
+ if fill:
+ artists.append(
+ axes.fill_between(x, FLOAT32_MINPOS, y, facecolor=color))
+
+ for artist in artists:
+ artist.set_zorder(z)
+ if alpha < 1:
+ artist.set_alpha(alpha)
+
+ return Container(artists)
+
+ def addImage(self, data, legend,
+ origin, scale, z,
+ selectable, draggable,
+ colormap, alpha):
+ # Non-uniform image
+ # http://wiki.scipy.org/Cookbook/Histograms
+ # Non-linear axes
+ # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib
+ for parameter in (data, legend, origin, scale, z,
+ selectable, draggable):
+ assert parameter is not None
+
+ origin = float(origin[0]), float(origin[1])
+ scale = float(scale[0]), float(scale[1])
+ height, width = data.shape[0:2]
+
+ picker = (selectable or draggable)
+
+ # Debian 7 specific support
+ # No transparent colormap with matplotlib < 1.2.0
+ # Add support for transparent colormap for uint8 data with
+ # colormap with 256 colors, linear norm, [0, 255] range
+ if matplotlib.__version__ < '1.2.0':
+ if (len(data.shape) == 2 and colormap['name'] is None and
+ 'colors' in colormap):
+ colors = numpy.array(colormap['colors'], copy=False)
+ if (colors.shape[-1] == 4 and
+ not numpy.all(numpy.equal(colors[3], 255))):
+ # This is a transparent colormap
+ if (colors.shape == (256, 4) and
+ colormap['normalization'] == 'linear' and
+ not colormap['autoscale'] and
+ colormap['vmin'] == 0 and
+ colormap['vmax'] == 255 and
+ data.dtype == numpy.uint8):
+ # Supported case, convert data to RGBA
+ data = colors[data.reshape(-1)].reshape(
+ data.shape + (4,))
+ else:
+ _logger.warning(
+ 'matplotlib %s does not support transparent '
+ 'colormap.', matplotlib.__version__)
+
+ if ((height * width) > 5.0e5 and
+ origin == (0., 0.) and scale == (1., 1.)):
+ imageClass = ModestImage
+ else:
+ imageClass = AxesImage
+
+ # the normalization can be a source of time waste
+ # Two possibilities, we receive data or a ready to show image
+ if len(data.shape) == 3: # RGBA image
+ image = imageClass(self.ax,
+ label="__IMAGE__" + legend,
+ interpolation='nearest',
+ picker=picker,
+ zorder=z,
+ origin='lower')
+
+ else:
+ # Convert colormap argument to matplotlib colormap
+ scalarMappable = Colors.getMPLScalarMappable(colormap, data)
+
+ # try as data
+ image = imageClass(self.ax,
+ label="__IMAGE__" + legend,
+ interpolation='nearest',
+ cmap=scalarMappable.cmap,
+ picker=picker,
+ zorder=z,
+ norm=scalarMappable.norm,
+ origin='lower')
+ if alpha < 1:
+ image.set_alpha(alpha)
+
+ # Set image extent
+ xmin = origin[0]
+ xmax = xmin + scale[0] * width
+ if scale[0] < 0.:
+ xmin, xmax = xmax, xmin
+
+ ymin = origin[1]
+ ymax = ymin + scale[1] * height
+ if scale[1] < 0.:
+ ymin, ymax = ymax, ymin
+
+ image.set_extent((xmin, xmax, ymin, ymax))
+
+ # Set image data
+ if scale[0] < 0. or scale[1] < 0.:
+ # For negative scale, step by -1
+ xstep = 1 if scale[0] >= 0. else -1
+ ystep = 1 if scale[1] >= 0. else -1
+ data = data[::ystep, ::xstep]
+
+ image.set_data(data)
+
+ self.ax.add_artist(image)
+
+ return image
+
+ def addItem(self, x, y, legend, shape, color, fill, overlay, z):
+ xView = numpy.array(x, copy=False)
+ yView = numpy.array(y, copy=False)
+
+ if shape == "line":
+ item = self.ax.plot(x, y, label=legend, color=color,
+ linestyle='-', marker=None)[0]
+
+ elif shape == "hline":
+ if hasattr(y, "__len__"):
+ y = y[-1]
+ item = self.ax.axhline(y, label=legend, color=color)
+
+ elif shape == "vline":
+ if hasattr(x, "__len__"):
+ x = x[-1]
+ item = self.ax.axvline(x, label=legend, color=color)
+
+ elif shape == 'rectangle':
+ xMin = numpy.nanmin(xView)
+ xMax = numpy.nanmax(xView)
+ yMin = numpy.nanmin(yView)
+ yMax = numpy.nanmax(yView)
+ w = xMax - xMin
+ h = yMax - yMin
+ item = Rectangle(xy=(xMin, yMin),
+ width=w,
+ height=h,
+ fill=False,
+ color=color)
+ if fill:
+ item.set_hatch('.')
+
+ self.ax.add_patch(item)
+
+ elif shape in ('polygon', 'polylines'):
+ xView = xView.reshape(1, -1)
+ yView = yView.reshape(1, -1)
+ item = Polygon(numpy.vstack((xView, yView)).T,
+ closed=(shape == 'polygon'),
+ fill=False,
+ label=legend,
+ color=color)
+ if fill and shape == 'polygon':
+ item.set_hatch('/')
+
+ self.ax.add_patch(item)
+
+ else:
+ raise NotImplementedError("Unsupported item shape %s" % shape)
+
+ item.set_zorder(z)
+
+ if overlay:
+ item.set_animated(True)
+ self._overlays.add(item)
+
+ return item
+
+ def addMarker(self, x, y, legend, text, color,
+ selectable, draggable,
+ symbol, constraint, overlay):
+ legend = "__MARKER__" + legend
+
+ if x is not None and y is not None:
+ line = self.ax.plot(x, y, label=legend,
+ linestyle=" ",
+ color=color,
+ marker=symbol,
+ markersize=10.)[-1]
+
+ if text is not None:
+ xtmp, ytmp = self.ax.transData.transform_point((x, y))
+ inv = self.ax.transData.inverted()
+ xtmp, ytmp = inv.transform_point((xtmp, ytmp))
+
+ if symbol is None:
+ valign = 'baseline'
+ else:
+ valign = 'top'
+ text = " " + text
+
+ line._infoText = self.ax.text(x, ytmp, text,
+ color=color,
+ horizontalalignment='left',
+ verticalalignment=valign)
+
+ elif x is not None:
+ line = self.ax.axvline(x, label=legend, color=color)
+ if text is not None:
+ text = " " + text
+ ymin, ymax = self.getGraphYLimits(axis='left')
+ delta = abs(ymax - ymin)
+ if ymin > ymax:
+ ymax = ymin
+ ymax -= 0.005 * delta
+ line._infoText = self.ax.text(x, ymax, text,
+ color=color,
+ horizontalalignment='left',
+ verticalalignment='top')
+
+ elif y is not None:
+ line = self.ax.axhline(y, label=legend, color=color)
+
+ if text is not None:
+ text = " " + text
+ xmin, xmax = self.getGraphXLimits()
+ delta = abs(xmax - xmin)
+ if xmin > xmax:
+ xmax = xmin
+ xmax -= 0.005 * delta
+ line._infoText = self.ax.text(xmax, y, text,
+ color=color,
+ horizontalalignment='right',
+ verticalalignment='top')
+
+ else:
+ raise RuntimeError('A marker must at least have one coordinate')
+
+ if selectable or draggable:
+ line.set_picker(5)
+
+ if overlay:
+ line.set_animated(True)
+ self._overlays.add(line)
+
+ return line
+
+ # Remove methods
+
+ def remove(self, item):
+ # Warning: It also needs to remove extra stuff if added as for markers
+ if hasattr(item, "_infoText"): # For markers text
+ item._infoText.remove()
+ item._infoText = None
+ self._overlays.discard(item)
+ item.remove()
+
+ # Interaction methods
+
+ def setGraphCursor(self, flag, color, linewidth, linestyle):
+ if flag:
+ lineh = self.ax.axhline(
+ self.ax.get_ybound()[0], visible=False, color=color,
+ linewidth=linewidth, linestyle=linestyle)
+ lineh.set_animated(True)
+
+ linev = self.ax.axvline(
+ self.ax.get_xbound()[0], visible=False, color=color,
+ linewidth=linewidth, linestyle=linestyle)
+ linev.set_animated(True)
+
+ self._graphCursor = lineh, linev
+ else:
+ if self._graphCursor is not None:
+ lineh, linev = self._graphCursor
+ lineh.remove()
+ linev.remove()
+ self._graphCursor = tuple()
+
+ # Active curve
+
+ def setCurveColor(self, curve, color):
+ # Store Line2D and PathCollection
+ for artist in curve.get_children():
+ if isinstance(artist, (Line2D, LineCollection)):
+ artist.set_color(color)
+ elif isinstance(artist, PathCollection):
+ artist.set_facecolors(color)
+ artist.set_edgecolors(color)
+ else:
+ _logger.warning(
+ 'setActiveCurve ignoring artist %s', str(artist))
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ return self.fig.canvas
+
+ def _enableAxis(self, axis, flag=True):
+ """Show/hide Y axis
+
+ :param str axis: Axis name: 'left' or 'right'
+ :param bool flag: Default, True
+ """
+ assert axis in ('right', 'left')
+ axes = self.ax2 if axis == 'right' else self.ax
+ axes.get_yaxis().set_visible(flag)
+
+ def replot(self):
+ """Do not perform rendering.
+
+ Override in subclass to actually draw something.
+ """
+ # TODO images, markers? scatter plot? move in remove?
+ # Right Y axis only support curve for now
+ # Hide right Y axis if no line is present
+ self._dirtyLimits = False
+ if not self.ax2.lines:
+ self._enableAxis('right', False)
+
+ def saveGraph(self, fileName, fileFormat, dpi):
+ # fileName can be also a StringIO or file instance
+ if dpi is not None:
+ self.fig.savefig(fileName, format=fileFormat, dpi=dpi)
+ else:
+ self.fig.savefig(fileName, format=fileFormat)
+ self._plot._setDirtyPlot()
+
+ # Graph labels
+
+ def setGraphTitle(self, title):
+ self.ax.set_title(title)
+
+ def setGraphXLabel(self, label):
+ self.ax.set_xlabel(label)
+
+ def setGraphYLabel(self, label, axis):
+ axes = self.ax if axis == 'left' else self.ax2
+ axes.set_ylabel(label)
+
+ # Graph limits
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ # Let matplotlib taking care of keep aspect ratio if any
+ self._dirtyLimits = True
+ self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))
+
+ if y2min is not None and y2max is not None:
+ if not self.isYAxisInverted():
+ self.ax2.set_ylim(min(y2min, y2max), max(y2min, y2max))
+ else:
+ self.ax2.set_ylim(max(y2min, y2max), min(y2min, y2max))
+
+ if not self.isYAxisInverted():
+ self.ax.set_ylim(min(ymin, ymax), max(ymin, ymax))
+ else:
+ self.ax.set_ylim(max(ymin, ymax), min(ymin, ymax))
+
+ def getGraphXLimits(self):
+ if self._dirtyLimits and self.isKeepDataAspectRatio():
+ self.replot() # makes sure we get the right limits
+ return self.ax.get_xbound()
+
+ def setGraphXLimits(self, xmin, xmax):
+ self._dirtyLimits = True
+ self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))
+
+ def getGraphYLimits(self, axis):
+ assert axis in ('left', 'right')
+ ax = self.ax2 if axis == 'right' else self.ax
+
+ if not ax.get_visible():
+ return None
+
+ if self._dirtyLimits and self.isKeepDataAspectRatio():
+ self.replot() # makes sure we get the right limits
+
+ return ax.get_ybound()
+
+ def setGraphYLimits(self, ymin, ymax, axis):
+ ax = self.ax2 if axis == 'right' else self.ax
+ if ymax < ymin:
+ ymin, ymax = ymax, ymin
+ self._dirtyLimits = True
+
+ if self.isKeepDataAspectRatio():
+ # matplotlib keeps limits of shared axis when keeping aspect ratio
+ # So x limits are kept when changing y limits....
+ # Change x limits first by taking into account aspect ratio
+ # and then change y limits.. so matplotlib does not need
+ # to make change (to y) to keep aspect ratio
+ xmin, xmax = ax.get_xbound()
+ curYMin, curYMax = ax.get_ybound()
+
+ newXRange = (xmax - xmin) * (ymax - ymin) / (curYMax - curYMin)
+ xcenter = 0.5 * (xmin + xmax)
+ ax.set_xlim(xcenter - 0.5 * newXRange, xcenter + 0.5 * newXRange)
+
+ if not self.isYAxisInverted():
+ ax.set_ylim(ymin, ymax)
+ else:
+ ax.set_ylim(ymax, ymin)
+
+ # Graph axes
+
+ def setXAxisLogarithmic(self, flag):
+ self.ax2.set_xscale('log' if flag else 'linear')
+ self.ax.set_xscale('log' if flag else 'linear')
+
+ def setYAxisLogarithmic(self, flag):
+ self.ax2.set_yscale('log' if flag else 'linear')
+ self.ax.set_yscale('log' if flag else 'linear')
+
+ def setYAxisInverted(self, flag):
+ if self.ax.yaxis_inverted() != bool(flag):
+ self.ax.invert_yaxis()
+
+ def isYAxisInverted(self):
+ return self.ax.yaxis_inverted()
+
+ def isKeepDataAspectRatio(self):
+ return self.ax.get_aspect() in (1.0, 'equal')
+
+ def setKeepDataAspectRatio(self, flag):
+ self.ax.set_aspect(1.0 if flag else 'auto')
+ self.ax2.set_aspect(1.0 if flag else 'auto')
+
+ def setGraphGrid(self, which):
+ self.ax.grid(False, which='both') # Disable all grid first
+ if which is not None:
+ self.ax.grid(True, which=which)
+
+ # Data <-> Pixel coordinates conversion
+
+ def dataToPixel(self, x, y, axis):
+ ax = self.ax2 if axis == "right" else self.ax
+
+ pixels = ax.transData.transform_point((x, y))
+ xPixel, yPixel = pixels.T
+ return xPixel, yPixel
+
+ def pixelToData(self, x, y, axis, check):
+ ax = self.ax2 if axis == "right" else self.ax
+
+ inv = ax.transData.inverted()
+ x, y = inv.transform_point((x, y))
+
+ if check:
+ xmin, xmax = self.getGraphXLimits()
+ ymin, ymax = self.getGraphYLimits(axis=axis)
+
+ if x > xmax or x < xmin or y > ymax or y < ymin:
+ return None # (x, y) is out of plot area
+
+ return x, y
+
+ def getPlotBoundsInPixels(self):
+ bbox = self.ax.get_window_extent().transformed(
+ self.fig.dpi_scale_trans.inverted())
+ dpi = self.fig.dpi
+ # Warning this is not returning int...
+ return (bbox.bounds[0] * dpi, bbox.bounds[1] * dpi,
+ bbox.bounds[2] * dpi, bbox.bounds[3] * dpi)
+
+
+class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib):
+ """QWidget matplotlib backend using a QtAgg canvas.
+
+ It adds fast overlay drawing and mouse event management.
+ """
+
+ _sigPostRedisplay = qt.Signal()
+ """Signal handling automatic asynchronous replot"""
+
+ def __init__(self, plot, parent=None):
+ self._insideResizeEventMethod = False
+
+ BackendMatplotlib.__init__(self, plot, parent)
+ FigureCanvasQTAgg.__init__(self, self.fig)
+ self.setParent(parent)
+
+ FigureCanvasQTAgg.setSizePolicy(
+ self, qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
+ FigureCanvasQTAgg.updateGeometry(self)
+
+ # Make postRedisplay asynchronous using Qt signal
+ self._sigPostRedisplay.connect(
+ super(BackendMatplotlibQt, self).postRedisplay,
+ qt.Qt.QueuedConnection)
+
+ self._picked = None
+
+ self.mpl_connect('button_press_event', self._onMousePress)
+ self.mpl_connect('button_release_event', self._onMouseRelease)
+ self.mpl_connect('motion_notify_event', self._onMouseMove)
+ self.mpl_connect('scroll_event', self._onMouseWheel)
+
+ def postRedisplay(self):
+ self._sigPostRedisplay.emit()
+
+ # Mouse event forwarding
+
+ _MPL_TO_PLOT_BUTTONS = {1: 'left', 2: 'middle', 3: 'right'}
+
+ def _onMousePress(self, event):
+ self._plot.onMousePress(
+ event.x, event.y, self._MPL_TO_PLOT_BUTTONS[event.button])
+
+ def _onMouseMove(self, event):
+ if self._graphCursor:
+ lineh, linev = self._graphCursor
+ if event.inaxes != self.ax and lineh.get_visible():
+ lineh.set_visible(False)
+ linev.set_visible(False)
+ self._plot._setDirtyPlot(overlayOnly=True)
+ else:
+ linev.set_visible(True)
+ linev.set_xdata((event.xdata, event.xdata))
+ lineh.set_visible(True)
+ lineh.set_ydata((event.ydata, event.ydata))
+ self._plot._setDirtyPlot(overlayOnly=True)
+ # onMouseMove must trigger replot if dirty flag is raised
+
+ self._plot.onMouseMove(event.x, event.y)
+
+ def _onMouseRelease(self, event):
+ self._plot.onMouseRelease(
+ event.x, event.y, self._MPL_TO_PLOT_BUTTONS[event.button])
+
+ def _onMouseWheel(self, event):
+ self._plot.onMouseWheel(event.x, event.y, event.step)
+
+ def leaveEvent(self, event):
+ """QWidget event handler"""
+ self._plot.onMouseLeaveWidget()
+
+ # picking
+
+ def _onPick(self, event):
+ # TODO not very nice and fragile, find a better way?
+ # Make a selection according to kind
+ if self._picked is None:
+ _logger.error('Internal picking error')
+ return
+
+ label = event.artist.get_label()
+ if label.startswith('__MARKER__'):
+ self._picked.append({'kind': 'marker', 'legend': label[10:]})
+
+ elif label.startswith('__IMAGE__'):
+ self._picked.append({'kind': 'image', 'legend': label[9:]})
+
+ else: # it's a curve, item have no picker for now
+ if isinstance(event.artist, PathCollection):
+ data = event.artist.get_offsets()[event.ind, :]
+ xdata, ydata = data[:, 0], data[:, 1]
+ elif isinstance(event.artist, Line2D):
+ xdata = event.artist.get_xdata()[event.ind]
+ ydata = event.artist.get_ydata()[event.ind]
+ else:
+ _logger.info('Unsupported artist, ignored')
+ return
+
+ self._picked.append({'kind': 'curve', 'legend': label,
+ 'xdata': xdata, 'ydata': ydata})
+
+ def pickItems(self, x, y):
+ self._picked = []
+
+ # Weird way to do an explicit picking: Simulate a button press event
+ mouseEvent = MouseEvent('button_press_event', self, x, y)
+ cid = self.mpl_connect('pick_event', self._onPick)
+ self.fig.pick(mouseEvent)
+ self.mpl_disconnect(cid)
+ picked = self._picked
+ self._picked = None
+
+ return picked
+
+ # replot control
+
+ def resizeEvent(self, event):
+ self._insideResizeEventMethod = True
+ # Need to dirty the whole plot on resize.
+ self._plot._setDirtyPlot()
+ FigureCanvasQTAgg.resizeEvent(self, event)
+ self._insideResizeEventMethod = False
+
+ def draw(self):
+ """Override canvas draw method to support faster draw of overlays."""
+ if self._plot._getDirtyPlot(): # Need a full redraw
+ FigureCanvasQTAgg.draw(self)
+ self._background = None # Any saved background is dirty
+
+ if (self._overlays or self._graphCursor or
+ self._plot._getDirtyPlot() == 'overlay'):
+ # There are overlays or crosshair, or they is just no more overlays
+
+ # Specific case: called from resizeEvent:
+ # avoid store/restore background, just draw the overlay
+ if not self._insideResizeEventMethod:
+ if self._background is None: # First store the background
+ self._background = self.copy_from_bbox(self.fig.bbox)
+
+ self.restore_region(self._background)
+
+ # This assume that items are only on left/bottom Axes
+ for item in self._overlays:
+ self.ax.draw_artist(item)
+
+ for item in self._graphCursor:
+ self.ax.draw_artist(item)
+
+ self.blit(self.fig.bbox)
+
+ def replot(self):
+ BackendMatplotlib.replot(self)
+ self.draw()
+
+ # cursor
+
+ _QT_CURSORS = {
+ None: qt.Qt.ArrowCursor,
+ BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor,
+ BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor,
+ BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor,
+ BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor,
+ BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor,
+ }
+
+ def setGraphCursorShape(self, cursor):
+ cursor = self._QT_CURSORS[cursor]
+
+ FigureCanvasQTAgg.setCursor(self, qt.QCursor(cursor))
diff --git a/silx/gui/plot/backends/BackendOpenGL.py b/silx/gui/plot/backends/BackendOpenGL.py
new file mode 100644
index 0000000..bc10eca
--- /dev/null
+++ b/silx/gui/plot/backends/BackendOpenGL.py
@@ -0,0 +1,1631 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""OpenGL Plot backend."""
+
+from __future__ import division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/03/2017"
+
+from collections import OrderedDict, namedtuple
+from ctypes import c_void_p
+import logging
+
+import numpy
+
+from .._utils import FLOAT32_MINPOS
+from . import BackendBase
+from .. import Colors
+from ... import qt
+
+from ..._glutils import gl
+from ... import _glutils as glu
+from .glutils import (
+ GLPlotCurve2D, GLPlotColormap, GLPlotRGBAImage, GLPlotFrame2D,
+ mat4Ortho, mat4Identity,
+ LEFT, RIGHT, BOTTOM, TOP,
+ Text2D, Shape2D)
+from .glutils.PlotImageFile import saveImageToFile
+
+_logger = logging.getLogger(__name__)
+
+
+# TODO idea: BackendQtMixIn class to share code between mpl and gl
+# TODO check if OpenGL is available
+# TODO make an off-screen mesa backend
+
+# Bounds ######################################################################
+
+class Range(namedtuple('Range', ('min_', 'max_'))):
+ """Describes a 1D range"""
+
+ @property
+ def range_(self):
+ return self.max_ - self.min_
+
+ @property
+ def center(self):
+ return 0.5 * (self.min_ + self.max_)
+
+
+class Bounds(object):
+ """Describes plot bounds with 2 y axis"""
+
+ def __init__(self, xMin, xMax, yMin, yMax, y2Min, y2Max):
+ self._xAxis = Range(xMin, xMax)
+ self._yAxis = Range(yMin, yMax)
+ self._y2Axis = Range(y2Min, y2Max)
+
+ def __repr__(self):
+ return "x: %s, y: %s, y2: %s" % (repr(self._xAxis),
+ repr(self._yAxis),
+ repr(self._y2Axis))
+
+ @property
+ def xAxis(self):
+ return self._xAxis
+
+ @property
+ def yAxis(self):
+ return self._yAxis
+
+ @property
+ def y2Axis(self):
+ return self._y2Axis
+
+
+# Content #####################################################################
+
+class PlotDataContent(object):
+ """Manage plot data content: images and curves.
+
+ This class is only meant to work with _OpenGLPlotCanvas.
+ """
+
+ _PRIMITIVE_TYPES = 'curve', 'image'
+
+ def __init__(self):
+ self._primitives = OrderedDict() # For images and curves
+
+ def add(self, primitive):
+ """Add a curve or image to the content dictionary.
+
+ This function generates the key in the dict from the primitive.
+
+ :param primitive: The primitive to add.
+ :type primitive: Instance of GLPlotCurve2D, GLPlotColormap,
+ GLPlotRGBAImage.
+ """
+ if isinstance(primitive, GLPlotCurve2D):
+ primitiveType = 'curve'
+ elif isinstance(primitive, (GLPlotColormap, GLPlotRGBAImage)):
+ primitiveType = 'image'
+ else:
+ raise RuntimeError('Unsupported object type: %s', primitive)
+
+ key = primitiveType, primitive.info['legend']
+ self._primitives[key] = primitive
+
+ def get(self, primitiveType, legend):
+ """Get the corresponding primitive of given type with given legend.
+
+ :param str primitiveType: Type of primitive ('curve' or 'image').
+ :param str legend: The legend of the primitive to retrieve.
+ :return: The corresponding curve or None if no such curve.
+ """
+ assert primitiveType in self._PRIMITIVE_TYPES
+ return self._primitives.get((primitiveType, legend))
+
+ def pop(self, primitiveType, key):
+ """Pop the corresponding curve or return None if no such curve.
+
+ :param str primitiveType:
+ :param str key:
+ :return:
+ """
+ assert primitiveType in self._PRIMITIVE_TYPES
+ return self._primitives.pop((primitiveType, key), None)
+
+ def zOrderedPrimitives(self, reverse=False):
+ """List of primitives sorted according to their z order.
+
+ It is a stable sort (as sorted):
+ Original order is preserved when key is the same.
+
+ :param bool reverse: Ascending (True, default) or descending (False).
+ """
+ return sorted(self._primitives.values(),
+ key=lambda primitive: primitive.info['zOrder'],
+ reverse=reverse)
+
+ def primitives(self):
+ """Iterator over all primitives."""
+ return self._primitives.values()
+
+ def primitiveKeys(self, primitiveType):
+ """Iterator over primitives of a specific type."""
+ assert primitiveType in self._PRIMITIVE_TYPES
+ for type_, key in self._primitives.keys():
+ if type_ == primitiveType:
+ yield key
+
+ def getBounds(self, xPositive=False, yPositive=False):
+ """Bounds of the data.
+
+ Can return strictly positive bounds (for log scale).
+ In this case, curves are clipped to their smaller positive value
+ and images with negative min are ignored.
+
+ :param bool xPositive: True to get strictly positive range.
+ :param bool yPositive: True to get strictly positive range.
+ :return: The range of data for x, y and y2, or default (1., 100.)
+ if no range found for one dimension.
+ :rtype: Bounds
+ """
+ xMin, yMin, y2Min = float('inf'), float('inf'), float('inf')
+ xMax = 0. if xPositive else -float('inf')
+ if yPositive:
+ yMax, y2Max = 0., 0.
+ else:
+ yMax, y2Max = -float('inf'), -float('inf')
+
+ for item in self._primitives.values():
+ # To support curve <= 0. and log and bypass images:
+ # If positive only, uses x|yMinPos if available
+ # and bypass other data with negative min bounds
+ if xPositive:
+ itemXMin = getattr(item, 'xMinPos', item.xMin)
+ if itemXMin is None or itemXMin < FLOAT32_MINPOS:
+ continue
+ else:
+ itemXMin = item.xMin
+
+ if yPositive:
+ itemYMin = getattr(item, 'yMinPos', item.yMin)
+ if itemYMin is None or itemYMin < FLOAT32_MINPOS:
+ continue
+ else:
+ itemYMin = item.yMin
+
+ if itemXMin < xMin:
+ xMin = itemXMin
+ if item.xMax > xMax:
+ xMax = item.xMax
+
+ if item.info.get('yAxis') == 'right':
+ if itemYMin < y2Min:
+ y2Min = itemYMin
+ if item.yMax > y2Max:
+ y2Max = item.yMax
+ else:
+ if itemYMin < yMin:
+ yMin = itemYMin
+ if item.yMax > yMax:
+ yMax = item.yMax
+
+ # One of the limit has not been updated, return default range
+ if xMin >= xMax:
+ xMin, xMax = 1., 100.
+ if yMin >= yMax:
+ yMin, yMax = 1., 100.
+ if y2Min >= y2Max:
+ y2Min, y2Max = 1., 100.
+
+ return Bounds(xMin, xMax, yMin, yMax, y2Min, y2Max)
+
+
+# shaders #####################################################################
+
+_baseVertShd = """
+ attribute vec2 position;
+ uniform mat4 matrix;
+ uniform bvec2 isLog;
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ void main(void) {
+ vec2 posTransformed = position;
+ if (isLog.x) {
+ posTransformed.x = oneOverLog10 * log(position.x);
+ }
+ if (isLog.y) {
+ posTransformed.y = oneOverLog10 * log(position.y);
+ }
+ gl_Position = matrix * vec4(posTransformed, 0.0, 1.0);
+ }
+ """
+
+_baseFragShd = """
+ uniform vec4 color;
+ uniform int hatchStep;
+ uniform float tickLen;
+
+ void main(void) {
+ if (tickLen != 0.) {
+ if (mod((gl_FragCoord.x + gl_FragCoord.y) / tickLen, 2.) < 1.) {
+ gl_FragColor = color;
+ } else {
+ discard;
+ }
+ } else if (hatchStep == 0 ||
+ mod(gl_FragCoord.x - gl_FragCoord.y, float(hatchStep)) == 0.) {
+ gl_FragColor = color;
+ } else {
+ discard;
+ }
+ }
+ """
+
+_texVertShd = """
+ attribute vec2 position;
+ attribute vec2 texCoords;
+ uniform mat4 matrix;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ coords = texCoords;
+ }
+ """
+
+_texFragShd = """
+ uniform sampler2D tex;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_FragColor = texture2D(tex, coords);
+ }
+ """
+
+
+# BackendOpenGL ###############################################################
+
+_current_context = None
+
+
+def _getContext():
+ assert _current_context is not None
+ return _current_context
+
+
+class BackendOpenGL(BackendBase.BackendBase, qt.QGLWidget):
+ """OpenGL-based Plot backend.
+
+ WARNINGS:
+ Unless stated otherwise, this API is NOT thread-safe and MUST be
+ called from the main thread.
+ When numpy arrays are passed as arguments to the API (through
+ :func:`addCurve` and :func:`addImage`), they are copied only if
+ required.
+ So, the caller should not modify these arrays afterwards.
+ """
+
+ _sigPostRedisplay = qt.Signal()
+ """Signal handling automatic asynchronous replot"""
+
+ def __init__(self, plot, parent=None):
+ qt.QGLWidget.__init__(self, parent)
+ BackendBase.BackendBase.__init__(self, plot, parent)
+
+ self.matScreenProj = mat4Identity()
+
+ self._progBase = glu.Program(
+ _baseVertShd, _baseFragShd, attrib0='position')
+ self._progTex = glu.Program(
+ _texVertShd, _texFragShd, attrib0='position')
+ self._plotFBOs = {}
+
+ self._keepDataAspectRatio = False
+
+ self._devicePixelRatio = 1.0
+
+ self._crosshairCursor = None
+ self._mousePosInPixels = None
+
+ self._markers = OrderedDict()
+ self._items = OrderedDict()
+ self._plotContent = PlotDataContent() # For images and curves
+ self._selectionAreas = OrderedDict()
+ self._glGarbageCollector = []
+
+ self._plotFrame = GLPlotFrame2D(
+ margins={'left': 100, 'right': 50, 'top': 50, 'bottom': 50})
+
+ # Make postRedisplay asynchronous using Qt signal
+ self._sigPostRedisplay.connect(
+ super(BackendOpenGL, self).postRedisplay,
+ qt.Qt.QueuedConnection)
+
+ # TODO is this needed? move it Plot?
+ self.setGraphXLimits(0., 100.)
+ self.setGraphYLimits(0., 100., axis='right')
+ self.setGraphYLimits(0., 100., axis='left')
+
+ self.setAutoFillBackground(False)
+ self.setMouseTracking(True)
+
+ # QWidget
+
+ _MOUSE_BTNS = {1: 'left', 2: 'right', 4: 'middle'}
+
+ def sizeHint(self):
+ return qt.QSize(8 * 80, 6 * 80) # Mimic MatplotlibBackend
+
+ def mousePressEvent(self, event):
+ xPixel = event.x() * self._devicePixelRatio
+ yPixel = event.y() * self._devicePixelRatio
+ btn = self._MOUSE_BTNS[event.button()]
+ self._plot.onMousePress(xPixel, yPixel, btn)
+ event.accept()
+
+ def mouseMoveEvent(self, event):
+ xPixel = event.x() * self._devicePixelRatio
+ yPixel = event.y() * self._devicePixelRatio
+
+ # Handle crosshair
+ inXPixel, inYPixel = self._mouseInPlotArea(xPixel, yPixel)
+ isCursorInPlot = inXPixel == xPixel and inYPixel == yPixel
+
+ previousMousePosInPixels = self._mousePosInPixels
+ self._mousePosInPixels = (xPixel, yPixel) if isCursorInPlot else None
+ if (self._crosshairCursor is not None and
+ previousMousePosInPixels != self._crosshairCursor):
+ # Avoid replot when cursor remains outside plot area
+ self._plot._setDirtyPlot(overlayOnly=True)
+
+ self._plot.onMouseMove(xPixel, yPixel)
+ event.accept()
+
+ def mouseReleaseEvent(self, event):
+ xPixel = event.x() * self._devicePixelRatio
+ yPixel = event.y() * self._devicePixelRatio
+
+ btn = self._MOUSE_BTNS[event.button()]
+ self._plot.onMouseRelease(xPixel, yPixel, btn)
+ event.accept()
+
+ def wheelEvent(self, event):
+ xPixel = event.x() * self._devicePixelRatio
+ yPixel = event.y() * self._devicePixelRatio
+
+ if hasattr(event, 'angleDelta'): # Qt 5
+ delta = event.angleDelta().y()
+ else: # Qt 4 support
+ delta = event.delta()
+ angleInDegrees = delta / 8.
+ self._plot.onMouseWheel(xPixel, yPixel, angleInDegrees)
+ event.accept()
+
+ def leaveEvent(self, _):
+ self._plot.onMouseLeaveWidget()
+
+ # QGLWidget API
+
+ @staticmethod
+ def _setBlendFuncGL():
+ # glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
+ gl.glBlendFuncSeparate(gl.GL_SRC_ALPHA,
+ gl.GL_ONE_MINUS_SRC_ALPHA,
+ gl.GL_ONE,
+ gl.GL_ONE)
+
+ def initializeGL(self):
+ gl.testGL()
+
+ gl.glClearColor(1., 1., 1., 1.)
+ gl.glClearStencil(0)
+
+ gl.glEnable(gl.GL_BLEND)
+ self._setBlendFuncGL()
+
+ # For lines
+ gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
+
+ # For points
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ def _paintDirectGL(self):
+ self._renderPlotAreaGL()
+ self._plotFrame.render()
+ self._renderMarkersGL()
+ self._renderOverlayGL()
+
+ def _paintFBOGL(self):
+ context = glu.getGLContext()
+ plotFBOTex = self._plotFBOs.get(context)
+ if (self._plot._getDirtyPlot() or self._plotFrame.isDirty or
+ plotFBOTex is None):
+ self._plotVertices = numpy.array(((-1., -1., 0., 0.),
+ (1., -1., 1., 0.),
+ (-1., 1., 0., 1.),
+ (1., 1., 1., 1.)),
+ dtype=numpy.float32)
+ if plotFBOTex is None or \
+ plotFBOTex.shape[1] != self._plotFrame.size[0] or \
+ plotFBOTex.shape[0] != self._plotFrame.size[1]:
+ if plotFBOTex is not None:
+ plotFBOTex.discard()
+ plotFBOTex = glu.FramebufferTexture(
+ gl.GL_RGBA,
+ shape=(self._plotFrame.size[1],
+ self._plotFrame.size[0]),
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=(gl.GL_CLAMP_TO_EDGE,
+ gl.GL_CLAMP_TO_EDGE))
+ self._plotFBOs[context] = plotFBOTex
+
+ with plotFBOTex:
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
+ self._renderPlotAreaGL()
+ self._plotFrame.render()
+
+ # Render plot in screen coords
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ self._progTex.use()
+ texUnit = 0
+
+ gl.glUniform1i(self._progTex.uniforms['tex'], texUnit)
+ gl.glUniformMatrix4fv(self._progTex.uniforms['matrix'], 1, gl.GL_TRUE,
+ mat4Identity())
+
+ stride = self._plotVertices.shape[-1] * self._plotVertices.itemsize
+ gl.glEnableVertexAttribArray(self._progTex.attributes['position'])
+ gl.glVertexAttribPointer(self._progTex.attributes['position'],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ stride, self._plotVertices)
+
+ texCoordsPtr = c_void_p(self._plotVertices.ctypes.data +
+ 2 * self._plotVertices.itemsize) # Better way?
+ gl.glEnableVertexAttribArray(self._progTex.attributes['texCoords'])
+ gl.glVertexAttribPointer(self._progTex.attributes['texCoords'],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ stride, texCoordsPtr)
+
+ with plotFBOTex.texture:
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._plotVertices))
+
+ self._renderMarkersGL()
+ self._renderOverlayGL()
+
+ def paintGL(self):
+ global _current_context
+ _current_context = self.context()
+
+ glu.setGLContextGetter(_getContext)
+
+ if hasattr(self, 'windowHandle'): # Qt 5
+ devicePixelRatio = self.windowHandle().devicePixelRatio()
+ if devicePixelRatio != self._devicePixelRatio:
+ self._devicePixelRatio = devicePixelRatio
+ self.resizeGL(int(self.width() * devicePixelRatio),
+ int(self.height() * devicePixelRatio))
+
+ # Release OpenGL resources
+ for item in self._glGarbageCollector:
+ item.discard()
+ self._glGarbageCollector = []
+
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
+
+ # Check if window is large enough
+ plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+ if plotWidth <= 2 or plotHeight <= 2:
+ return
+
+ # self._paintDirectGL()
+ self._paintFBOGL()
+
+ glu.setGLContextGetter()
+ _current_context = None
+
+ def _nonOrthoAxesLineMarkerPrimitives(self, marker, pixelOffset):
+ """Generates the vertices and label for a line marker.
+
+ :param dict marker: Description of a line marker
+ :param int pixelOffset: Offset of text from borders in pixels
+ :return: Line vertices and Text label or None
+ :rtype: 2-tuple (2x2 numpy.array of float, Text2D)
+ """
+ label, vertices = None, None
+
+ xCoord, yCoord = marker['x'], marker['y']
+ assert xCoord is None or yCoord is None # Specific to line markers
+
+ # Get plot corners in data coords
+ plotLeft, plotTop, plotWidth, plotHeight = self.getPlotBoundsInPixels()
+
+ corners = [(plotLeft, plotTop),
+ (plotLeft, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop)]
+ corners = numpy.array([self.pixelToData(x, y, axis='left', check=False)
+ for (x, y) in corners])
+
+ borders = {
+ 'right': (corners[3], corners[2]),
+ 'top': (corners[0], corners[3]),
+ 'bottom': (corners[2], corners[1]),
+ 'left': (corners[1], corners[0])
+ }
+
+ textLayouts = { # align, valign, offsets
+ 'right': (RIGHT, BOTTOM, (-1., -1.)),
+ 'top': (LEFT, TOP, (1., 1.)),
+ 'bottom': (LEFT, BOTTOM, (1., -1.)),
+ 'left': (LEFT, BOTTOM, (1., -1.))
+ }
+
+ if xCoord is None: # Horizontal line in data space
+ if marker['text'] is not None:
+ # Find intersection of hline with borders in data
+ # Order is important as it stops at first intersection
+ for border_name in ('right', 'top', 'bottom', 'left'):
+ (x0, y0), (x1, y1) = borders[border_name]
+
+ if min(y0, y1) <= yCoord < max(y0, y1):
+ xIntersect = (yCoord - y0) * (x1 - x0) / (y1 - y0) + x0
+
+ # Add text label
+ pixelPos = self.dataToPixel(
+ xIntersect, yCoord, axis='left', check=False)
+
+ align, valign, offsets = textLayouts[border_name]
+
+ x = pixelPos[0] + offsets[0] * pixelOffset
+ y = pixelPos[1] + offsets[1] * pixelOffset
+ label = Text2D(marker['text'], x, y,
+ color=marker['color'],
+ bgColor=(1., 1., 1., 0.5),
+ align=align, valign=valign)
+ break # Stop at first intersection
+
+ xMin, xMax = corners[:, 0].min(), corners[:, 0].max()
+ vertices = numpy.array(
+ ((xMin, yCoord), (xMax, yCoord)), dtype=numpy.float32)
+
+ else: # yCoord is None: vertical line in data space
+ if marker['text'] is not None:
+ # Find intersection of hline with borders in data
+ # Order is important as it stops at first intersection
+ for border_name in ('top', 'bottom', 'right', 'left'):
+ (x0, y0), (x1, y1) = borders[border_name]
+ if min(x0, x1) <= xCoord < max(x0, x1):
+ yIntersect = (xCoord - x0) * (y1 - y0) / (x1 - x0) + y0
+
+ # Add text label
+ pixelPos = self.dataToPixel(
+ xCoord, yIntersect, axis='left', check=False)
+
+ align, valign, offsets = textLayouts[border_name]
+
+ x = pixelPos[0] + offsets[0] * pixelOffset
+ y = pixelPos[1] + offsets[1] * pixelOffset
+ label = Text2D(marker['text'], x, y,
+ color=marker['color'],
+ bgColor=(1., 1., 1., 0.5),
+ align=align, valign=valign)
+ break # Stop at first intersection
+
+ yMin, yMax = corners[:, 1].min(), corners[:, 1].max()
+ vertices = numpy.array(
+ ((xCoord, yMin), (xCoord, yMax)), dtype=numpy.float32)
+
+ return vertices, label
+
+ def _renderMarkersGL(self):
+ if len(self._markers) == 0:
+ return
+
+ plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+
+ isXLog = self._plotFrame.xAxis.isLog
+ isYLog = self._plotFrame.yAxis.isLog
+
+ # Render in plot area
+ gl.glScissor(self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth, plotHeight)
+ gl.glEnable(gl.GL_SCISSOR_TEST)
+
+ gl.glViewport(self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth, plotHeight)
+
+ # Prepare vertical and horizontal markers rendering
+ self._progBase.use()
+ gl.glUniformMatrix4fv(self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
+ self._plotFrame.transformedDataProjMat)
+ gl.glUniform2i(self._progBase.uniforms['isLog'], isXLog, isYLog)
+ gl.glUniform1i(self._progBase.uniforms['hatchStep'], 0)
+ gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
+ posAttrib = self._progBase.attributes['position']
+
+ labels = []
+ pixelOffset = 3
+
+ for marker in self._markers.values():
+ xCoord, yCoord = marker['x'], marker['y']
+
+ if ((isXLog and xCoord is not None and
+ xCoord < FLOAT32_MINPOS) or
+ (isYLog and yCoord is not None and
+ yCoord < FLOAT32_MINPOS)):
+ # Do not render markers with negative coords on log axis
+ continue
+
+ if xCoord is None or yCoord is None:
+ if not self.isDefaultBaseVectors(): # Non-orthogonal axes
+ vertices, label = self._nonOrthoAxesLineMarkerPrimitives(
+ marker, pixelOffset)
+ if label is not None:
+ labels.append(label)
+
+ else: # Orthogonal axes
+ pixelPos = self.dataToPixel(
+ xCoord, yCoord, axis='left', check=False)
+
+ if xCoord is None: # Horizontal line in data space
+ if marker['text'] is not None:
+ x = self._plotFrame.size[0] - \
+ self._plotFrame.margins.right - pixelOffset
+ y = pixelPos[1] - pixelOffset
+ label = Text2D(marker['text'], x, y,
+ color=marker['color'],
+ bgColor=(1., 1., 1., 0.5),
+ align=RIGHT, valign=BOTTOM)
+ labels.append(label)
+
+ xMin, xMax = self._plotFrame.dataRanges.x
+ vertices = numpy.array(((xMin, yCoord),
+ (xMax, yCoord)),
+ dtype=numpy.float32)
+
+ else: # yCoord is None: vertical line in data space
+ if marker['text'] is not None:
+ x = pixelPos[0] + pixelOffset
+ y = self._plotFrame.margins.top + pixelOffset
+ label = Text2D(marker['text'], x, y,
+ color=marker['color'],
+ bgColor=(1., 1., 1., 0.5),
+ align=LEFT, valign=TOP)
+ labels.append(label)
+
+ yMin, yMax = self._plotFrame.dataRanges.y
+ vertices = numpy.array(((xCoord, yMin),
+ (xCoord, yMax)),
+ dtype=numpy.float32)
+
+ self._progBase.use()
+
+ gl.glUniform4f(self._progBase.uniforms['color'],
+ *marker['color'])
+
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, vertices)
+ gl.glLineWidth(1)
+ gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
+
+ else:
+ pixelPos = self.dataToPixel(
+ xCoord, yCoord, axis='left', check=True)
+ if pixelPos is None:
+ # Do not render markers outside visible plot area
+ continue
+
+ if marker['text'] is not None:
+ x = pixelPos[0] + pixelOffset
+ y = pixelPos[1] + pixelOffset
+ label = Text2D(marker['text'], x, y,
+ color=marker['color'],
+ bgColor=(1., 1., 1., 0.5),
+ align=LEFT, valign=TOP)
+ labels.append(label)
+
+ # For now simple implementation: using a curve for each marker
+ # Should pack all markers to a single set of points
+ markerCurve = GLPlotCurve2D(
+ numpy.array((xCoord,), dtype=numpy.float32),
+ numpy.array((yCoord,), dtype=numpy.float32),
+ marker=marker['symbol'],
+ markerColor=marker['color'],
+ markerSize=11)
+ markerCurve.render(self._plotFrame.transformedDataProjMat,
+ isXLog, isYLog)
+ markerCurve.discard()
+
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ # Render marker labels
+ for label in labels:
+ label.render(self.matScreenProj)
+
+ gl.glDisable(gl.GL_SCISSOR_TEST)
+
+ def _renderOverlayGL(self):
+ # Render selection area and crosshair cursor
+ if self._selectionAreas or self._crosshairCursor is not None:
+ plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+
+ # Scissor to plot area
+ gl.glScissor(self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth, plotHeight)
+ gl.glEnable(gl.GL_SCISSOR_TEST)
+
+ self._progBase.use()
+ gl.glUniform2i(self._progBase.uniforms['isLog'],
+ self._plotFrame.xAxis.isLog,
+ self._plotFrame.yAxis.isLog)
+ gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
+ posAttrib = self._progBase.attributes['position']
+ matrixUnif = self._progBase.uniforms['matrix']
+ colorUnif = self._progBase.uniforms['color']
+ hatchStepUnif = self._progBase.uniforms['hatchStep']
+
+ # Render selection area in plot area
+ if self._selectionAreas:
+ gl.glViewport(self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth, plotHeight)
+
+ gl.glUniformMatrix4fv(matrixUnif, 1, gl.GL_TRUE,
+ self._plotFrame.transformedDataProjMat)
+
+ for shape in self._selectionAreas.values():
+ if shape.isVideoInverted:
+ gl.glBlendFunc(gl.GL_ONE_MINUS_DST_COLOR, gl.GL_ZERO)
+
+ shape.render(posAttrib, colorUnif, hatchStepUnif)
+
+ if shape.isVideoInverted:
+ self._setBlendFuncGL()
+
+ # Render crosshair cursor is screen frame but with scissor
+ if (self._crosshairCursor is not None and
+ self._mousePosInPixels is not None):
+ gl.glViewport(
+ 0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ gl.glUniformMatrix4fv(matrixUnif, 1, gl.GL_TRUE,
+ self.matScreenProj)
+
+ color, lineWidth = self._crosshairCursor
+ gl.glUniform4f(colorUnif, *color)
+ gl.glUniform1i(hatchStepUnif, 0)
+
+ xPixel, yPixel = self._mousePosInPixels
+ xPixel, yPixel = xPixel + 0.5, yPixel + 0.5
+ vertices = numpy.array(((0., yPixel),
+ (self._plotFrame.size[0], yPixel),
+ (xPixel, 0.),
+ (xPixel, self._plotFrame.size[1])),
+ dtype=numpy.float32)
+
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, vertices)
+ gl.glLineWidth(lineWidth)
+ gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
+
+ gl.glDisable(gl.GL_SCISSOR_TEST)
+
+ def _renderPlotAreaGL(self):
+ plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+
+ self._plotFrame.renderGrid()
+
+ gl.glScissor(self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth, plotHeight)
+ gl.glEnable(gl.GL_SCISSOR_TEST)
+
+ # Matrix
+ trBounds = self._plotFrame.transformedDataRanges
+ if trBounds.x[0] == trBounds.x[1] or \
+ trBounds.y[0] == trBounds.y[1]:
+ return
+
+ isXLog = self._plotFrame.xAxis.isLog
+ isYLog = self._plotFrame.yAxis.isLog
+
+ gl.glViewport(self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth, plotHeight)
+
+ # Render images and curves
+ # sorted is stable: original order is preserved when key is the same
+ for item in self._plotContent.zOrderedPrimitives():
+ if item.info.get('yAxis') == 'right':
+ item.render(self._plotFrame.transformedDataY2ProjMat,
+ isXLog, isYLog)
+ else:
+ item.render(self._plotFrame.transformedDataProjMat,
+ isXLog, isYLog)
+
+ # Render Items
+ self._progBase.use()
+ gl.glUniformMatrix4fv(self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
+ self._plotFrame.transformedDataProjMat)
+ gl.glUniform2i(self._progBase.uniforms['isLog'],
+ self._plotFrame.xAxis.isLog,
+ self._plotFrame.yAxis.isLog)
+ gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
+
+ for item in self._items.values():
+ shape2D = item.get('_shape2D')
+ if shape2D is None:
+ shape2D = Shape2D(tuple(zip(item['x'], item['y'])),
+ fill=item['fill'],
+ fillColor=item['color'],
+ stroke=True,
+ strokeColor=item['color'])
+ item['_shape2D'] = shape2D
+
+ if ((isXLog and shape2D.xMin < FLOAT32_MINPOS) or
+ (isYLog and shape2D.yMin < FLOAT32_MINPOS)):
+ # Ignore items <= 0. on log axes
+ continue
+
+ posAttrib = self._progBase.attributes['position']
+ colorUnif = self._progBase.uniforms['color']
+ hatchStepUnif = self._progBase.uniforms['hatchStep']
+ shape2D.render(posAttrib, colorUnif, hatchStepUnif)
+
+ gl.glDisable(gl.GL_SCISSOR_TEST)
+
+ def resizeGL(self, width, height):
+ if width == 0 or height == 0: # Do not resize
+ return
+ self._plotFrame.size = width, height
+
+ self.matScreenProj = mat4Ortho(0, self._plotFrame.size[0],
+ self._plotFrame.size[1], 0,
+ 1, -1)
+
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \
+ self._plotFrame.dataRanges
+ self.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+
+ # Add methods
+
+ def addCurve(self, x, y, legend,
+ color, symbol, linewidth, linestyle,
+ yaxis,
+ xerror, yerror, z, selectable,
+ fill, alpha, symbolsize):
+ for parameter in (x, y, legend, color, symbol, linewidth, linestyle,
+ yaxis, z, selectable, fill, symbolsize):
+ assert parameter is not None
+ assert yaxis in ('left', 'right')
+
+ x = numpy.array(x, dtype=numpy.float32, copy=False, order='C')
+ y = numpy.array(y, dtype=numpy.float32, copy=False, order='C')
+ if xerror is not None:
+ xerror = numpy.array(
+ xerror, dtype=numpy.float32, copy=False, order='C')
+ if yerror is not None:
+ yerror = numpy.array(
+ yerror, dtype=numpy.float32, copy=False, order='C')
+
+ # TODO check and improve this
+ if (len(color) == 4 and
+ type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
+ color = numpy.array(color, dtype=numpy.float32) / 255.
+
+ if isinstance(color, numpy.ndarray) and color.ndim == 2:
+ colorArray = color
+ color = None
+ else:
+ colorArray = None
+ color = Colors.rgba(color)
+
+ if alpha < 1.: # Apply image transparency
+ if colorArray is not None and colorArray.shape[1] == 4:
+ # multiply alpha channel
+ colorArray[:, 3] = colorArray[:, 3] * alpha
+ if color is not None:
+ color = color[0], color[1], color[2], color[3] * alpha
+
+ behaviors = set()
+ if selectable:
+ behaviors.add('selectable')
+
+ curve = GLPlotCurve2D(x, y, colorArray,
+ xError=xerror,
+ yError=yerror,
+ lineStyle=linestyle,
+ lineColor=color,
+ lineWidth=linewidth,
+ marker=symbol,
+ markerColor=color,
+ markerSize=symbolsize,
+ fillColor=color if fill else None)
+ curve.info = {
+ 'legend': legend,
+ 'zOrder': z,
+ 'behaviors': behaviors,
+ 'yAxis': 'left' if yaxis is None else yaxis,
+ }
+
+ if yaxis == "right":
+ self._plotFrame.isY2Axis = True
+
+ self._plotContent.add(curve)
+
+ return legend, 'curve'
+
+ def addImage(self, data, legend,
+ origin, scale, z,
+ selectable, draggable,
+ colormap, alpha):
+ for parameter in (data, legend, origin, scale, z,
+ selectable, draggable):
+ assert parameter is not None
+
+ behaviors = set()
+ if selectable:
+ behaviors.add('selectable')
+ if draggable:
+ behaviors.add('draggable')
+
+ if data.ndim == 2:
+ # Ensure array is contiguous and eventually convert its type
+ if data.dtype in (numpy.float32, numpy.uint8, numpy.uint16):
+ data = numpy.array(data, copy=False, order='C')
+ else:
+ _logger.info(
+ 'addImage: Convert %s data to float32', str(data.dtype))
+ data = numpy.array(data, dtype=numpy.float32, order='C')
+
+ colormapIsLog = colormap['normalization'].startswith('log')
+
+ if colormap['autoscale']:
+ cmapRange = None
+ else:
+ cmapRange = colormap['vmin'], colormap['vmax']
+ assert cmapRange[0] <= cmapRange[1]
+
+ # Retrieve colormap LUT from name and color array
+ colormapLut = Colors.applyColormapToData(
+ numpy.arange(256, dtype=numpy.uint8),
+ name=colormap['name'],
+ normalization='linear',
+ autoscale=False,
+ vmin=0,
+ vmax=255,
+ colors=colormap.get('colors'))
+
+ image = GLPlotColormap(data,
+ origin,
+ scale,
+ colormapLut,
+ colormapIsLog,
+ cmapRange,
+ alpha)
+ image.info = {
+ 'legend': legend,
+ 'zOrder': z,
+ 'behaviors': behaviors
+ }
+ self._plotContent.add(image)
+
+ elif len(data.shape) == 3:
+ # For RGB, RGBA data
+ assert data.shape[2] in (3, 4)
+ assert data.dtype in (numpy.float32, numpy.uint8)
+
+ image = GLPlotRGBAImage(data, origin, scale, alpha)
+
+ image.info = {
+ 'legend': legend,
+ 'zOrder': z,
+ 'behaviors': behaviors
+ }
+
+ if self._plotFrame.xAxis.isLog and image.xMin <= 0.:
+ raise RuntimeError(
+ 'Cannot add image with X <= 0 with X axis log scale')
+ if self._plotFrame.yAxis.isLog and image.yMin <= 0.:
+ raise RuntimeError(
+ 'Cannot add image with Y <= 0 with Y axis log scale')
+
+ self._plotContent.add(image)
+
+ else:
+ raise RuntimeError("Unsupported data shape {0}".format(data.shape))
+
+ return legend, 'image'
+
+ def addItem(self, x, y, legend, shape, color, fill, overlay, z):
+ # TODO handle overlay
+ if shape not in ('polygon', 'rectangle', 'line', 'vline', 'hline'):
+ raise NotImplementedError("Unsupported shape {0}".format(shape))
+
+ x = numpy.array(x, copy=False)
+ y = numpy.array(y, copy=False)
+
+ if shape == 'rectangle':
+ xMin, xMax = x
+ x = numpy.array((xMin, xMin, xMax, xMax))
+ yMin, yMax = y
+ y = numpy.array((yMin, yMax, yMax, yMin))
+
+ # TODO is this needed?
+ if self._plotFrame.xAxis.isLog and x.min() <= 0.:
+ raise RuntimeError(
+ 'Cannot add item with X <= 0 with X axis log scale')
+ if self._plotFrame.yAxis.isLog and y.min() <= 0.:
+ raise RuntimeError(
+ 'Cannot add item with Y <= 0 with Y axis log scale')
+
+ self._items[legend] = {
+ 'shape': shape,
+ 'color': Colors.rgba(color),
+ 'fill': 'hatch' if fill else None,
+ 'x': x,
+ 'y': y
+ }
+
+ return legend, 'item'
+
+ def addMarker(self, x, y, legend, text, color,
+ selectable, draggable,
+ symbol, constraint, overlay):
+ # TODO handle overlay
+
+ if symbol is None:
+ symbol = '+'
+
+ behaviors = set()
+ if selectable:
+ behaviors.add('selectable')
+ if draggable:
+ behaviors.add('draggable')
+
+ # Apply constraint to provided position
+ isConstraint = (draggable and constraint is not None and
+ x is not None and y is not None)
+ if isConstraint:
+ x, y = constraint(x, y)
+
+ if x is not None and self._plotFrame.xAxis.isLog and x <= 0.:
+ raise RuntimeError(
+ 'Cannot add marker with X <= 0 with X axis log scale')
+ if y is not None and self._plotFrame.yAxis.isLog and y <= 0.:
+ raise RuntimeError(
+ 'Cannot add marker with Y <= 0 with Y axis log scale')
+
+ self._markers[legend] = {
+ 'x': x,
+ 'y': y,
+ 'legend': legend,
+ 'text': text,
+ 'color': Colors.rgba(color),
+ 'behaviors': behaviors,
+ 'constraint': constraint if isConstraint else None,
+ 'symbol': symbol,
+ }
+
+ return legend, 'marker'
+
+ # Remove methods
+
+ def remove(self, item):
+ legend, kind = item
+
+ if kind == 'curve':
+ curve = self._plotContent.pop('curve', legend)
+ if curve is not None:
+ # Check if some curves remains on the right Y axis
+ y2AxisItems = (item for item in self._plotContent.primitives()
+ if item.info.get('yAxis', 'left') == 'right')
+ self._plotFrame.isY2Axis = next(y2AxisItems, None) is not None
+
+ self._glGarbageCollector.append(curve)
+
+ elif kind == 'image':
+ image = self._plotContent.pop('image', legend)
+ if image is not None:
+ self._glGarbageCollector.append(image)
+
+ elif kind == 'marker':
+ self._markers.pop(legend, False)
+
+ elif kind == 'item':
+ self._items.pop(legend, False)
+
+ else:
+ _logger.error('Unsupported kind: %s', str(kind))
+
+ # Interaction methods
+
+ _QT_CURSORS = {
+ None: qt.Qt.ArrowCursor,
+ BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor,
+ BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor,
+ BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor,
+ BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor,
+ BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor,
+ }
+
+ def setGraphCursorShape(self, cursor):
+ cursor = self._QT_CURSORS[cursor]
+
+ super(BackendOpenGL, self).setCursor(qt.QCursor(cursor))
+
+ def setGraphCursor(self, flag, color, linewidth, linestyle):
+ if linestyle is not '-':
+ _logger.warning(
+ "BackendOpenGL.setGraphCursor linestyle parameter ignored")
+
+ if flag:
+ color = Colors.rgba(color)
+ crosshairCursor = color, linewidth
+ else:
+ crosshairCursor = None
+
+ if crosshairCursor != self._crosshairCursor:
+ self._crosshairCursor = crosshairCursor
+
+ _PICK_OFFSET = 3 # Offset in pixel used for picking
+
+ def _mouseInPlotArea(self, x, y):
+ xPlot = numpy.clip(
+ x, self._plotFrame.margins.left,
+ self._plotFrame.size[0] - self._plotFrame.margins.right - 1)
+ yPlot = numpy.clip(
+ y, self._plotFrame.margins.top,
+ self._plotFrame.size[1] - self._plotFrame.margins.bottom - 1)
+ return xPlot, yPlot
+
+ def pickItems(self, x, y):
+ picked = []
+
+ dataPos = self.pixelToData(x, y, axis='left', check=True)
+ if dataPos is not None:
+ # Pick markers
+ for marker in reversed(list(self._markers.values())):
+ pixelPos = self.dataToPixel(
+ marker['x'], marker['y'], axis='left', check=False)
+ if pixelPos is None: # negative coord on a log axis
+ continue
+
+ if marker['x'] is None: # Horizontal line
+ pt1 = self.pixelToData(
+ x, y - self._PICK_OFFSET, axis='left', check=False)
+ pt2 = self.pixelToData(
+ x, y + self._PICK_OFFSET, axis='left', check=False)
+ isPicked = (min(pt1[1], pt2[1]) <= marker['y'] <=
+ max(pt1[1], pt2[1]))
+
+ elif marker['y'] is None: # Vertical line
+ pt1 = self.pixelToData(
+ x - self._PICK_OFFSET, y, axis='left', check=False)
+ pt2 = self.pixelToData(
+ x + self._PICK_OFFSET, y, axis='left', check=False)
+ isPicked = (min(pt1[0], pt2[0]) <= marker['x'] <=
+ max(pt1[0], pt2[0]))
+
+ else:
+ isPicked = (
+ numpy.fabs(x - pixelPos[0]) <= self._PICK_OFFSET and
+ numpy.fabs(y - pixelPos[1]) <= self._PICK_OFFSET)
+
+ if isPicked:
+ picked.append(dict(kind='marker',
+ legend=marker['legend']))
+
+ # Pick image and curves
+ for item in self._plotContent.zOrderedPrimitives(reverse=True):
+ if isinstance(item, (GLPlotColormap, GLPlotRGBAImage)):
+ pickedPos = item.pick(*dataPos)
+ if pickedPos is not None:
+ picked.append(dict(kind='image',
+ legend=item.info['legend']))
+
+ elif isinstance(item, GLPlotCurve2D):
+ offset = self._PICK_OFFSET
+ if item.marker is not None:
+ offset = max(item.markerSize / 2., offset)
+ if item.lineStyle is not None:
+ offset = max(item.lineWidth / 2., offset)
+
+ yAxis = item.info['yAxis']
+
+ inAreaPos = self._mouseInPlotArea(x - offset, y - offset)
+ dataPos = self.pixelToData(inAreaPos[0], inAreaPos[1],
+ axis=yAxis, check=True)
+ if dataPos is None:
+ continue
+ xPick0, yPick0 = dataPos
+
+ inAreaPos = self._mouseInPlotArea(x + offset, y + offset)
+ dataPos = self.pixelToData(inAreaPos[0], inAreaPos[1],
+ axis=yAxis, check=True)
+ if dataPos is None:
+ continue
+ xPick1, yPick1 = dataPos
+
+ if xPick0 < xPick1:
+ xPickMin, xPickMax = xPick0, xPick1
+ else:
+ xPickMin, xPickMax = xPick1, xPick0
+
+ if yPick0 < yPick1:
+ yPickMin, yPickMax = yPick0, yPick1
+ else:
+ yPickMin, yPickMax = yPick1, yPick0
+
+ pickedIndices = item.pick(xPickMin, yPickMin,
+ xPickMax, yPickMax)
+ if pickedIndices:
+ picked.append(dict(kind='curve',
+ legend=item.info['legend'],
+ xdata=item.xData[pickedIndices],
+ ydata=item.yData[pickedIndices]))
+
+ return picked
+
+ # Update curve
+
+ def setCurveColor(self, curve, color):
+ pass # TODO
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ return self
+
+ def postRedisplay(self):
+ self._sigPostRedisplay.emit()
+
+ def replot(self):
+ self.update() # async redraw
+ # self.repaint() # immediate redraw
+
+ def saveGraph(self, fileName, fileFormat, dpi):
+ if dpi is not None:
+ _logger.warning("saveGraph ignores dpi parameter")
+
+ if fileFormat not in ['png', 'ppm', 'svg', 'tiff']:
+ raise NotImplementedError('Unsupported format: %s' % fileFormat)
+
+ self.makeCurrent()
+
+ data = numpy.empty(
+ (self._plotFrame.size[1], self._plotFrame.size[0], 3),
+ dtype=numpy.uint8, order='C')
+
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
+ gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
+ gl.glReadPixels(0, 0, self._plotFrame.size[0], self._plotFrame.size[1],
+ gl.GL_RGB, gl.GL_UNSIGNED_BYTE, data)
+
+ # glReadPixels gives bottom to top,
+ # while images are stored as top to bottom
+ data = numpy.flipud(data)
+
+ # fileName is either a file-like object or a str
+ saveImageToFile(data, fileName, fileFormat)
+
+ # Graph labels
+
+ def setGraphTitle(self, title):
+ self._plotFrame.title = title
+
+ def setGraphXLabel(self, label):
+ self._plotFrame.xAxis.title = label
+
+ def setGraphYLabel(self, label, axis):
+ if axis == 'left':
+ self._plotFrame.yAxis.title = label
+ else: # right axis
+ if label:
+ _logger.warning('Right axis label not implemented')
+
+ # Non orthogonal axes
+
+ def setBaseVectors(self, x=(1., 0.), y=(0., 1.)):
+ """Set base vectors.
+
+ Useful for non-orthogonal axes.
+ If an axis is in log scale, skew is applied to log transformed values.
+
+ Base vector does not work well with log axes, to investi
+ """
+ if x != (1., 0.) and y != (0., 1.):
+ if self._plotFrame.xAxis.isLog:
+ _logger.warning("setBaseVectors disables X axis logarithmic.")
+ self.setXAxisLogarithmic(False)
+ if self._plotFrame.yAxis.isLog:
+ _logger.warning("setBaseVectors disables Y axis logarithmic.")
+ self.setYAxisLogarithmic(False)
+
+ if self.isKeepDataAspectRatio():
+ _logger.warning("setBaseVectors disables keepDataAspectRatio.")
+ self.keepDataAspectRatio(False)
+
+ self._plotFrame.baseVectors = x, y
+
+ def getBaseVectors(self):
+ return self._plotFrame.baseVectors
+
+ def isDefaultBaseVectors(self):
+ return self._plotFrame.baseVectors == \
+ self._plotFrame.DEFAULT_BASE_VECTORS
+
+ # Graph limits
+
+ def _setDataRanges(self, xlim=None, ylim=None, y2lim=None):
+ """Set the visible range of data in the plot frame.
+
+ This clips the ranges to possible values (takes care of float32
+ range + positive range for log).
+ This also takes care of non-orthogonal axes.
+
+ This should be moved to PlotFrame.
+ """
+ # Update axes range with a clipped range if too wide
+ self._plotFrame.setDataRanges(xlim, ylim, y2lim)
+
+ if not self.isDefaultBaseVectors():
+ # Update axes range with axes bounds in data coords
+ plotLeft, plotTop, plotWidth, plotHeight = \
+ self.getPlotBoundsInPixels()
+
+ self._plotFrame.xAxis.dataRange = sorted([
+ self.pixelToData(x, y, axis='left', check=False)[0]
+ for (x, y) in ((plotLeft, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop + plotHeight))])
+
+ self._plotFrame.yAxis.dataRange = sorted([
+ self.pixelToData(x, y, axis='left', check=False)[1]
+ for (x, y) in ((plotLeft, plotTop + plotHeight),
+ (plotLeft, plotTop))])
+
+ self._plotFrame.y2Axis.dataRange = sorted([
+ self.pixelToData(x, y, axis='right', check=False)[1]
+ for (x, y) in ((plotLeft + plotWidth, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop))])
+
+ def _ensureAspectRatio(self, keepDim=None):
+ """Update plot bounds in order to keep aspect ratio.
+
+ Warning: keepDim on right Y axis is not implemented !
+
+ :param str keepDim: The dimension to maintain: 'x', 'y' or None.
+ If None (the default), the dimension with the largest range.
+ """
+ plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+ if plotWidth <= 2 or plotHeight <= 2:
+ return
+
+ if keepDim is None:
+ dataBounds = self._plotContent.getBounds(
+ self._plotFrame.xAxis.isLog, self._plotFrame.yAxis.isLog)
+ if dataBounds.yAxis.range_ != 0.:
+ dataRatio = dataBounds.xAxis.range_
+ dataRatio /= float(dataBounds.yAxis.range_)
+
+ plotRatio = plotWidth / float(plotHeight) # Test != 0 before
+
+ keepDim = 'x' if dataRatio > plotRatio else 'y'
+ else: # Limit case
+ keepDim = 'x'
+
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \
+ self._plotFrame.dataRanges
+ if keepDim == 'y':
+ dataW = (yMax - yMin) * plotWidth / float(plotHeight)
+ xCenter = 0.5 * (xMin + xMax)
+ xMin = xCenter - 0.5 * dataW
+ xMax = xCenter + 0.5 * dataW
+ elif keepDim == 'x':
+ dataH = (xMax - xMin) * plotHeight / float(plotWidth)
+ yCenter = 0.5 * (yMin + yMax)
+ yMin = yCenter - 0.5 * dataH
+ yMax = yCenter + 0.5 * dataH
+ y2Center = 0.5 * (y2Min + y2Max)
+ y2Min = y2Center - 0.5 * dataH
+ y2Max = y2Center + 0.5 * dataH
+ else:
+ raise RuntimeError('Unsupported dimension to keep: %s' % keepDim)
+
+ # Update plot frame bounds
+ self._setDataRanges(xlim=(xMin, xMax),
+ ylim=(yMin, yMax),
+ y2lim=(y2Min, y2Max))
+
+ def _setPlotBounds(self, xRange=None, yRange=None, y2Range=None,
+ keepDim=None):
+ # Update axes range with a clipped range if too wide
+ self._setDataRanges(xlim=xRange,
+ ylim=yRange,
+ y2lim=y2Range)
+
+ # Keep data aspect ratio
+ if self.isKeepDataAspectRatio():
+ self._ensureAspectRatio(keepDim)
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ assert xmin < xmax
+ assert ymin < ymax
+
+ if y2min is None or y2max is None:
+ y2Range = None
+ else:
+ assert y2min < y2max
+ y2Range = y2min, y2max
+ self._setPlotBounds((xmin, xmax), (ymin, ymax), y2Range)
+
+ def getGraphXLimits(self):
+ return self._plotFrame.dataRanges.x
+
+ def setGraphXLimits(self, xmin, xmax):
+ assert xmin < xmax
+ self._setPlotBounds(xRange=(xmin, xmax), keepDim='x')
+
+ def getGraphYLimits(self, axis):
+ assert axis in ("left", "right")
+ if axis == "left":
+ return self._plotFrame.dataRanges.y
+ else:
+ return self._plotFrame.dataRanges.y2
+
+ def setGraphYLimits(self, ymin, ymax, axis):
+ assert ymin < ymax
+ assert axis in ("left", "right")
+
+ if axis == "left":
+ self._setPlotBounds(yRange=(ymin, ymax), keepDim='y')
+ else:
+ self._setPlotBounds(y2Range=(ymin, ymax), keepDim='y')
+
+ # Graph axes
+
+ def setXAxisLogarithmic(self, flag):
+ if flag != self._plotFrame.xAxis.isLog:
+ if flag and self._keepDataAspectRatio:
+ _logger.warning(
+ "KeepDataAspectRatio is ignored with log axes")
+
+ if flag and not self.isDefaultBaseVectors():
+ _logger.warning(
+ "setXAxisLogarithmic ignored because baseVectors are set")
+ return
+
+ self._plotFrame.xAxis.isLog = flag
+
+ def setYAxisLogarithmic(self, flag):
+ if (flag != self._plotFrame.yAxis.isLog or
+ flag != self._plotFrame.y2Axis.isLog):
+ if flag and self._keepDataAspectRatio:
+ _logger.warning(
+ "KeepDataAspectRatio is ignored with log axes")
+
+ if flag and not self.isDefaultBaseVectors():
+ _logger.warning(
+ "setYAxisLogarithmic ignored because baseVectors are set")
+ return
+
+ self._plotFrame.yAxis.isLog = flag
+ self._plotFrame.y2Axis.isLog = flag
+
+ def setYAxisInverted(self, flag):
+ if flag != self._plotFrame.isYAxisInverted:
+ self._plotFrame.isYAxisInverted = flag
+
+ def isYAxisInverted(self):
+ return self._plotFrame.isYAxisInverted
+
+ def isKeepDataAspectRatio(self):
+ if self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog:
+ return False
+ else:
+ return self._keepDataAspectRatio
+
+ def setKeepDataAspectRatio(self, flag):
+ if flag and (self._plotFrame.xAxis.isLog or
+ self._plotFrame.yAxis.isLog):
+ _logger.warning("KeepDataAspectRatio is ignored with log axes")
+ if flag and not self.isDefaultBaseVectors():
+ _logger.warning(
+ "keepDataAspectRatio ignored because baseVectors are set")
+
+ self._keepDataAspectRatio = flag
+
+ def setGraphGrid(self, which):
+ assert which in (None, 'major', 'both')
+ self._plotFrame.grid = which is not None # TODO True grid support
+
+ # Data <-> Pixel coordinates conversion
+
+ def dataToPixel(self, x, y, axis, check=False):
+ assert axis in ('left', 'right')
+
+ if x is None or y is None:
+ dataBounds = self._plotContent.getBounds(
+ self._plotFrame.xAxis.isLog, self._plotFrame.yAxis.isLog)
+
+ if x is None:
+ x = dataBounds.xAxis.center
+
+ if y is None:
+ if axis == 'left':
+ y = dataBounds.yAxis.center
+ else:
+ y = dataBounds.y2Axis.center
+
+ result = self._plotFrame.dataToPixel(x, y, axis)
+
+ if check and result is not None:
+ xPixel, yPixel = result
+ width, height = self._plotFrame.size
+ if (xPixel < self._plotFrame.margins.left or
+ xPixel > (width - self._plotFrame.margins.right) or
+ yPixel < self._plotFrame.margins.top or
+ yPixel > height - self._plotFrame.margins.bottom):
+ return None # (x, y) is out of plot area
+
+ return result
+
+ def pixelToData(self, x, y, axis, check):
+ assert axis in ("left", "right")
+
+ if x is None:
+ x = self._plotFrame.size[0] / 2.
+ if y is None:
+ y = self._plotFrame.size[1] / 2.
+
+ if check and (x < self._plotFrame.margins.left or
+ x > (self._plotFrame.size[0] -
+ self._plotFrame.margins.right) or
+ y < self._plotFrame.margins.top or
+ y > (self._plotFrame.size[1] -
+ self._plotFrame.margins.bottom)):
+ return None # (x, y) is out of plot area
+
+ return self._plotFrame.pixelToData(x, y, axis)
+
+ def getPlotBoundsInPixels(self):
+ return self._plotFrame.plotOrigin + self._plotFrame.plotSize
diff --git a/silx/gui/plot/backends/ModestImage.py b/silx/gui/plot/backends/ModestImage.py
new file mode 100644
index 0000000..93fba5a
--- /dev/null
+++ b/silx/gui/plot/backends/ModestImage.py
@@ -0,0 +1,174 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Matplotlib computationally modest image class."""
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/02/2016"
+
+
+import numpy
+
+from matplotlib import cbook
+from matplotlib.image import AxesImage
+
+
+class ModestImage(AxesImage):
+ """Computationally modest image class.
+
+Customization of https://github.com/ChrisBeaumont/ModestImage to allow
+extent support.
+
+ModestImage is an extension of the Matplotlib AxesImage class
+better suited for the interactive display of larger images. Before
+drawing, ModestImage resamples the data array based on the screen
+resolution and view window. This has very little affect on the
+appearance of the image, but can substantially cut down on
+computation since calculations of unresolved or clipped pixels
+are skipped.
+
+The interface of ModestImage is the same as AxesImage. However, it
+does not currently support setting the 'extent' property. There
+may also be weird coordinate warping operations for images that
+I'm not aware of. Don't expect those to work either.
+"""
+ def __init__(self, *args, **kwargs):
+ self._full_res = None
+ self._sx, self._sy = None, None
+ self._bounds = (None, None, None, None)
+ self._origExtent = None
+ super(ModestImage, self).__init__(*args, **kwargs)
+ if 'extent' in kwargs and kwargs['extent'] is not None:
+ self.set_extent(kwargs['extent'])
+
+ def set_extent(self, extent):
+ super(ModestImage, self).set_extent(extent)
+ if self._origExtent is None:
+ self._origExtent = self.get_extent()
+
+ def get_image_extent(self):
+ """Returns the extent of the whole image.
+
+ get_extent returns the extent of the drawn area and not of the full
+ image.
+
+ :return: Bounds of the image (x0, x1, y0, y1).
+ :rtype: Tuple of 4 floats.
+ """
+ if self._origExtent is not None:
+ return self._origExtent
+ else:
+ return self.get_extent()
+
+ def set_data(self, A):
+ """
+ Set the image array
+
+ ACCEPTS: numpy/PIL Image A
+ """
+
+ self._full_res = A
+ self._A = A
+
+ if (self._A.dtype != numpy.uint8 and
+ not numpy.can_cast(self._A.dtype, numpy.float)):
+ raise TypeError("Image data can not convert to float")
+
+ if (self._A.ndim not in (2, 3) or
+ (self._A.ndim == 3 and self._A.shape[-1] not in (3, 4))):
+ raise TypeError("Invalid dimensions for image data")
+
+ self._imcache = None
+ self._rgbacache = None
+ self._oldxslice = None
+ self._oldyslice = None
+ self._sx, self._sy = None, None
+
+ def get_array(self):
+ """Override to return the full-resolution array"""
+ return self._full_res
+
+ def _scale_to_res(self):
+ """ Change self._A and _extent to render an image whose
+resolution is matched to the eventual rendering."""
+ # extent has to be set BEFORE set_data
+ if self._origExtent is None:
+ if self.origin == "upper":
+ self._origExtent = (0, self._full_res.shape[1],
+ self._full_res.shape[0], 0)
+ else:
+ self._origExtent = (0, self._full_res.shape[1],
+ 0, self._full_res.shape[0])
+
+ if self.origin == "upper":
+ origXMin, origXMax, origYMax, origYMin = self._origExtent[0:4]
+ else:
+ origXMin, origXMax, origYMin, origYMax = self._origExtent[0:4]
+ ax = self.axes
+ ext = ax.transAxes.transform([1, 1]) - ax.transAxes.transform([0, 0])
+ xlim, ylim = ax.get_xlim(), ax.get_ylim()
+ xlim = max(xlim[0], origXMin), min(xlim[1], origXMax)
+ if ylim[0] > ylim[1]:
+ ylim = max(ylim[1], origYMin), min(ylim[0], origYMax)
+ else:
+ ylim = max(ylim[0], origYMin), min(ylim[1], origYMax)
+ # print("THOSE LIMITS ARE TO BE COMPARED WITH THE EXTENT")
+ # print("IN ORDER TO KNOW WHAT IT IS LIMITING THE DISPLAY")
+ # print("IF THE AXES OR THE EXTENT")
+ dx, dy = xlim[1] - xlim[0], ylim[1] - ylim[0]
+
+ y0 = max(0, ylim[0] - 5)
+ y1 = min(self._full_res.shape[0], ylim[1] + 5)
+ x0 = max(0, xlim[0] - 5)
+ x1 = min(self._full_res.shape[1], xlim[1] + 5)
+ y0, y1, x0, x1 = [int(a) for a in [y0, y1, x0, x1]]
+
+ sy = int(max(1, min((y1 - y0) / 5., numpy.ceil(dy / ext[1]))))
+ sx = int(max(1, min((x1 - x0) / 5., numpy.ceil(dx / ext[0]))))
+
+ # have we already calculated what we need?
+ if (self._sx is not None) and (self._sy is not None):
+ if (sx >= self._sx and sy >= self._sy and
+ x0 >= self._bounds[0] and x1 <= self._bounds[1] and
+ y0 >= self._bounds[2] and y1 <= self._bounds[3]):
+ return
+
+ self._A = self._full_res[y0:y1:sy, x0:x1:sx]
+ self._A = cbook.safe_masked_invalid(self._A)
+ x1 = x0 + self._A.shape[1] * sx
+ y1 = y0 + self._A.shape[0] * sy
+
+ if self.origin == "upper":
+ self.set_extent([x0, x1, y1, y0])
+ else:
+ self.set_extent([x0, x1, y0, y1])
+ self._sx = sx
+ self._sy = sy
+ self._bounds = (x0, x1, y0, y1)
+ self.changed()
+
+ def draw(self, renderer, *args, **kwargs):
+ self._scale_to_res()
+ super(ModestImage, self).draw(renderer, *args, **kwargs)
diff --git a/silx/gui/plot/backends/__init__.py b/silx/gui/plot/backends/__init__.py
new file mode 100644
index 0000000..966d9df
--- /dev/null
+++ b/silx/gui/plot/backends/__init__.py
@@ -0,0 +1,29 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package implements the backend of the Plot."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/03/2017"
diff --git a/silx/gui/plot/backends/_matplotlib.py b/silx/gui/plot/backends/_matplotlib.py
new file mode 100644
index 0000000..26732a0
--- /dev/null
+++ b/silx/gui/plot/backends/_matplotlib.py
@@ -0,0 +1,64 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module inits matplotlib and setups the backend to use.
+
+It MUST be imported prior to any other import of matplotlib.
+
+It provides the matplotlib :class:`FigureCanvasQTAgg` class corresponding
+to the used backend.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/10/2016"
+
+
+import sys
+import logging
+
+
+_logger = logging.getLogger(__name__)
+
+if 'matplotlib' in sys.modules:
+ _logger.warning(
+ 'matplotlib already loaded, setting its backend may not work')
+
+
+from ... import qt
+
+import matplotlib
+
+if qt.BINDING == 'PySide':
+ matplotlib.rcParams['backend'] = 'Qt4Agg'
+ matplotlib.rcParams['backend.qt4'] = 'PySide'
+ from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg # noqa
+
+elif qt.BINDING == 'PyQt4':
+ matplotlib.rcParams['backend'] = 'Qt4Agg'
+ from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg # noqa
+
+elif qt.BINDING == 'PyQt5':
+ matplotlib.rcParams['backend'] = 'Qt5Agg'
+ from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg # noqa
diff --git a/silx/gui/plot/backends/glutils/GLPlotCurve.py b/silx/gui/plot/backends/glutils/GLPlotCurve.py
new file mode 100644
index 0000000..4f08054
--- /dev/null
+++ b/silx/gui/plot/backends/glutils/GLPlotCurve.py
@@ -0,0 +1,1317 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides classes to render 2D lines and scatter plots
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import math
+import logging
+
+import numpy
+
+from silx.math.combo import min_max
+
+from ...._glutils import gl
+from ...._glutils import numpyToGLType, Program, vertexBuffer
+from ..._utils import FLOAT32_MINPOS
+from .GLSupport import buildFillMaskIndices
+
+
+_logger = logging.getLogger(__name__)
+
+
+_MPL_NONES = None, 'None', '', ' '
+
+
+# fill ########################################################################
+
+class _Fill2D(object):
+ _LINEAR, _LOG10_X, _LOG10_Y, _LOG10_X_Y = 0, 1, 2, 3
+
+ _SHADERS = {
+ 'vertexTransforms': {
+ _LINEAR: """
+ vec4 transformXY(float x, float y) {
+ return vec4(x, y, 0.0, 1.0);
+ }
+ """,
+ _LOG10_X: """
+ const float oneOverLog10 = 0.43429448190325176;
+
+ vec4 transformXY(float x, float y) {
+ return vec4(oneOverLog10 * log(x), y, 0.0, 1.0);
+ }
+ """,
+ _LOG10_Y: """
+ const float oneOverLog10 = 0.43429448190325176;
+
+ vec4 transformXY(float x, float y) {
+ return vec4(x, oneOverLog10 * log(y), 0.0, 1.0);
+ }
+ """,
+ _LOG10_X_Y: """
+ const float oneOverLog10 = 0.43429448190325176;
+
+ vec4 transformXY(float x, float y) {
+ return vec4(oneOverLog10 * log(x),
+ oneOverLog10 * log(y),
+ 0.0, 1.0);
+ }
+ """
+ },
+ 'vertex': """
+ #version 120
+
+ uniform mat4 matrix;
+ attribute float xPos;
+ attribute float yPos;
+
+ %s
+
+ void main(void) {
+ gl_Position = matrix * transformXY(xPos, yPos);
+ }
+ """,
+ 'fragment': """
+ #version 120
+
+ uniform vec4 color;
+
+ void main(void) {
+ gl_FragColor = color;
+ }
+ """
+ }
+
+ _programs = {
+ _LINEAR: Program(
+ _SHADERS['vertex'] % _SHADERS['vertexTransforms'][_LINEAR],
+ _SHADERS['fragment'], attrib0='xPos'),
+ _LOG10_X: Program(
+ _SHADERS['vertex'] % _SHADERS['vertexTransforms'][_LOG10_X],
+ _SHADERS['fragment'], attrib0='xPos'),
+ _LOG10_Y: Program(
+ _SHADERS['vertex'] % _SHADERS['vertexTransforms'][_LOG10_Y],
+ _SHADERS['fragment'], attrib0='xPos'),
+ _LOG10_X_Y: Program(
+ _SHADERS['vertex'] % _SHADERS['vertexTransforms'][_LOG10_X_Y],
+ _SHADERS['fragment'], attrib0='xPos'),
+ }
+
+ def __init__(self, xFillVboData=None, yFillVboData=None,
+ xMin=None, yMin=None, xMax=None, yMax=None,
+ color=(0., 0., 0., 1.)):
+ self.xFillVboData = xFillVboData
+ self.yFillVboData = yFillVboData
+ self.xMin, self.yMin = xMin, yMin
+ self.xMax, self.yMax = xMax, yMax
+ self.color = color
+
+ self._bboxVertices = None
+ self._indices = None
+ self._indicesType = None
+
+ def prepare(self):
+ if self._indices is None:
+ self._indices = buildFillMaskIndices(self.xFillVboData.size)
+ self._indicesType = numpyToGLType(self._indices.dtype)
+
+ if self._bboxVertices is None:
+ yMin, yMax = min(self.yMin, 1e-32), max(self.yMax, 1e-32)
+ self._bboxVertices = numpy.array(((self.xMin, self.xMin,
+ self.xMax, self.xMax),
+ (yMin, yMax, yMin, yMax)),
+ dtype=numpy.float32)
+
+ def render(self, matrix, isXLog, isYLog):
+ self.prepare()
+
+ if isXLog:
+ transform = self._LOG10_X_Y if isYLog else self._LOG10_X
+ else:
+ transform = self._LOG10_Y if isYLog else self._LINEAR
+
+ prog = self._programs[transform]
+ prog.use()
+
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix)
+
+ gl.glUniform4f(prog.uniforms['color'], *self.color)
+
+ xPosAttrib = prog.attributes['xPos']
+ yPosAttrib = prog.attributes['yPos']
+
+ gl.glEnableVertexAttribArray(xPosAttrib)
+ self.xFillVboData.setVertexAttrib(xPosAttrib)
+
+ gl.glEnableVertexAttribArray(yPosAttrib)
+ self.yFillVboData.setVertexAttrib(yPosAttrib)
+
+ # Prepare fill mask
+ gl.glEnable(gl.GL_STENCIL_TEST)
+ gl.glStencilMask(1)
+ gl.glStencilFunc(gl.GL_ALWAYS, 1, 1)
+ gl.glStencilOp(gl.GL_INVERT, gl.GL_INVERT, gl.GL_INVERT)
+ gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
+ gl.glDepthMask(gl.GL_FALSE)
+
+ gl.glDrawElements(gl.GL_TRIANGLE_STRIP, self._indices.size,
+ self._indicesType, self._indices)
+
+ gl.glStencilFunc(gl.GL_EQUAL, 1, 1)
+ # Reset stencil while drawing
+ gl.glStencilOp(gl.GL_ZERO, gl.GL_ZERO, gl.GL_ZERO)
+ gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
+ gl.glDepthMask(gl.GL_TRUE)
+
+ gl.glVertexAttribPointer(xPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0,
+ self._bboxVertices[0])
+ gl.glVertexAttribPointer(yPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0,
+ self._bboxVertices[1])
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, self._bboxVertices[0].size)
+
+ gl.glDisable(gl.GL_STENCIL_TEST)
+
+
+# line ########################################################################
+
+SOLID, DASHED, DASHDOT, DOTTED = '-', '--', '-.', ':'
+
+
+class _Lines2D(object):
+ STYLES = SOLID, DASHED, DASHDOT, DOTTED
+ """Supported line styles"""
+
+ _LINEAR, _LOG10_X, _LOG10_Y, _LOG10_X_Y = 0, 1, 2, 3
+
+ _SHADERS = {
+ 'vertexTransforms': {
+ _LINEAR: """
+ vec4 transformXY(float x, float y) {
+ return vec4(x, y, 0.0, 1.0);
+ }
+ """,
+ _LOG10_X: """
+ const float oneOverLog10 = 0.43429448190325176;
+
+ vec4 transformXY(float x, float y) {
+ return vec4(oneOverLog10 * log(x), y, 0.0, 1.0);
+ }
+ """,
+ _LOG10_Y: """
+ const float oneOverLog10 = 0.43429448190325176;
+
+ vec4 transformXY(float x, float y) {
+ return vec4(x, oneOverLog10 * log(y), 0.0, 1.0);
+ }
+ """,
+ _LOG10_X_Y: """
+ const float oneOverLog10 = 0.43429448190325176;
+
+ vec4 transformXY(float x, float y) {
+ return vec4(oneOverLog10 * log(x),
+ oneOverLog10 * log(y),
+ 0.0, 1.0);
+ }
+ """
+ },
+ 'solid': {
+ 'vertex': """
+ #version 120
+
+ uniform mat4 matrix;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+
+ varying vec4 vColor;
+
+ %s
+
+ void main(void) {
+ gl_Position = matrix * transformXY(xPos, yPos);
+ vColor = color;
+ }
+ """,
+ 'fragment': """
+ #version 120
+
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_FragColor = vColor;
+ }
+ """
+ },
+
+
+ # Limitation: Dash using an estimate of distance in screen coord
+ # to avoid computing distance when viewport is resized
+ # results in inequal dashes when viewport aspect ratio is far from 1
+ 'dashed': {
+ 'vertex': """
+ #version 120
+
+ uniform mat4 matrix;
+ uniform vec2 halfViewportSize;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+ attribute float distance;
+
+ varying float vDist;
+ varying vec4 vColor;
+
+ %s
+
+ void main(void) {
+ gl_Position = matrix * transformXY(xPos, yPos);
+ //Estimate distance in pixels
+ vec2 probe = vec2(matrix * vec4(1., 1., 0., 0.)) *
+ halfViewportSize;
+ float pixelPerDataEstimate = length(probe)/sqrt(2.);
+ vDist = distance * pixelPerDataEstimate;
+ vColor = color;
+ }
+ """,
+ 'fragment': """
+ #version 120
+
+ /* Dashes: [0, x], [y, z]
+ Dash period: w */
+ uniform vec4 dash;
+
+ varying float vDist;
+ varying vec4 vColor;
+
+ void main(void) {
+ float dist = mod(vDist, dash.w);
+ if ((dist > dash.x && dist < dash.y) || dist > dash.z) {
+ discard;
+ }
+ gl_FragColor = vColor;
+ }
+ """
+ }
+ }
+
+ _programs = {}
+
+ def __init__(self, xVboData=None, yVboData=None,
+ colorVboData=None, distVboData=None,
+ style=SOLID, color=(0., 0., 0., 1.),
+ width=1, dashPeriod=20, drawMode=None):
+ self.xVboData = xVboData
+ self.yVboData = yVboData
+ self.distVboData = distVboData
+ self.colorVboData = colorVboData
+ self.useColorVboData = colorVboData is not None
+
+ self.color = color
+ self._width = 1
+ self.width = width
+ self._style = None
+ self.style = style
+ self.dashPeriod = dashPeriod
+
+ self._drawMode = drawMode if drawMode is not None else gl.GL_LINE_STRIP
+
+ @property
+ def style(self):
+ return self._style
+
+ @style.setter
+ def style(self, style):
+ if style in _MPL_NONES:
+ self._style = None
+ self.render = self._renderNone
+ else:
+ assert style in self.STYLES
+ self._style = style
+ if style == SOLID:
+ self.render = self._renderSolid
+ else: # DASHED, DASHDOT, DOTTED
+ self.render = self._renderDash
+
+ @property
+ def width(self):
+ return self._width
+
+ @width.setter
+ def width(self, width):
+ # try:
+ # widthRange = self._widthRange
+ # except AttributeError:
+ # widthRange = gl.glGetFloatv(gl.GL_ALIASED_LINE_WIDTH_RANGE)
+ # # Shared among contexts, this should be enough..
+ # _Lines2D._widthRange = widthRange
+ # assert width >= widthRange[0] and width <= widthRange[1]
+ self._width = width
+
+ @classmethod
+ def _getProgram(cls, transform, style):
+ try:
+ prgm = cls._programs[(transform, style)]
+ except KeyError:
+ sources = cls._SHADERS[style]
+ vertexShdr = sources['vertex'] % \
+ cls._SHADERS['vertexTransforms'][transform]
+ prgm = Program(vertexShdr, sources['fragment'], attrib0='xPos')
+ cls._programs[(transform, style)] = prgm
+ return prgm
+
+ @classmethod
+ def init(cls):
+ gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
+
+ def _renderNone(self, matrix, isXLog, isYLog):
+ pass
+
+ render = _renderNone # Overridden in style setter
+
+ def _renderSolid(self, matrix, isXLog, isYLog):
+ if isXLog:
+ transform = self._LOG10_X_Y if isYLog else self._LOG10_X
+ else:
+ transform = self._LOG10_Y if isYLog else self._LINEAR
+
+ prog = self._getProgram(transform, 'solid')
+ prog.use()
+
+ gl.glEnable(gl.GL_LINE_SMOOTH)
+
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix)
+
+ colorAttrib = prog.attributes['color']
+ if self.useColorVboData and self.colorVboData is not None:
+ gl.glEnableVertexAttribArray(colorAttrib)
+ self.colorVboData.setVertexAttrib(colorAttrib)
+ else:
+ gl.glDisableVertexAttribArray(colorAttrib)
+ gl.glVertexAttrib4f(colorAttrib, *self.color)
+
+ xPosAttrib = prog.attributes['xPos']
+ gl.glEnableVertexAttribArray(xPosAttrib)
+ self.xVboData.setVertexAttrib(xPosAttrib)
+
+ yPosAttrib = prog.attributes['yPos']
+ gl.glEnableVertexAttribArray(yPosAttrib)
+ self.yVboData.setVertexAttrib(yPosAttrib)
+
+ gl.glLineWidth(self.width)
+ gl.glDrawArrays(self._drawMode, 0, self.xVboData.size)
+
+ gl.glDisable(gl.GL_LINE_SMOOTH)
+
+ def _renderDash(self, matrix, isXLog, isYLog):
+ if isXLog:
+ transform = self._LOG10_X_Y if isYLog else self._LOG10_X
+ else:
+ transform = self._LOG10_Y if isYLog else self._LINEAR
+
+ prog = self._getProgram(transform, 'dashed')
+ prog.use()
+
+ gl.glEnable(gl.GL_LINE_SMOOTH)
+
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix)
+ x, y, viewWidth, viewHeight = gl.glGetFloatv(gl.GL_VIEWPORT)
+ gl.glUniform2f(prog.uniforms['halfViewportSize'],
+ 0.5 * viewWidth, 0.5 * viewHeight)
+
+ if self.style == DOTTED:
+ dash = (0.1 * self.dashPeriod,
+ 0.6 * self.dashPeriod,
+ 0.7 * self.dashPeriod,
+ self.dashPeriod)
+ elif self.style == DASHDOT:
+ dash = (0.3 * self.dashPeriod,
+ 0.5 * self.dashPeriod,
+ 0.6 * self.dashPeriod,
+ self.dashPeriod)
+ else:
+ dash = (0.5 * self.dashPeriod,
+ self.dashPeriod,
+ self.dashPeriod,
+ self.dashPeriod)
+
+ gl.glUniform4f(prog.uniforms['dash'], *dash)
+
+ colorAttrib = prog.attributes['color']
+ if self.useColorVboData and self.colorVboData is not None:
+ gl.glEnableVertexAttribArray(colorAttrib)
+ self.colorVboData.setVertexAttrib(colorAttrib)
+ else:
+ gl.glDisableVertexAttribArray(colorAttrib)
+ gl.glVertexAttrib4f(colorAttrib, *self.color)
+
+ distAttrib = prog.attributes['distance']
+ gl.glEnableVertexAttribArray(distAttrib)
+ self.distVboData.setVertexAttrib(distAttrib)
+
+ xPosAttrib = prog.attributes['xPos']
+ gl.glEnableVertexAttribArray(xPosAttrib)
+ self.xVboData.setVertexAttrib(xPosAttrib)
+
+ yPosAttrib = prog.attributes['yPos']
+ gl.glEnableVertexAttribArray(yPosAttrib)
+ self.yVboData.setVertexAttrib(yPosAttrib)
+
+ gl.glLineWidth(self.width)
+ gl.glDrawArrays(self._drawMode, 0, self.xVboData.size)
+
+ gl.glDisable(gl.GL_LINE_SMOOTH)
+
+
+def _distancesFromArrays(xData, yData):
+ deltas = numpy.dstack((
+ numpy.ediff1d(xData, to_begin=numpy.float32(0.)),
+ numpy.ediff1d(yData, to_begin=numpy.float32(0.))))[0]
+ return numpy.cumsum(numpy.sqrt(numpy.sum(deltas ** 2, axis=1)))
+
+
+# points ######################################################################
+
+DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK = \
+ 'd', 'o', 's', '+', 'x', '.', ',', '*'
+
+H_LINE, V_LINE = '_', '|'
+
+
+class _Points2D(object):
+ MARKERS = (DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK,
+ H_LINE, V_LINE)
+
+ _LINEAR, _LOG10_X, _LOG10_Y, _LOG10_X_Y = 0, 1, 2, 3
+
+ _SHADERS = {
+ 'vertexTransforms': {
+ _LINEAR: """
+ vec4 transformXY(float x, float y) {
+ return vec4(x, y, 0.0, 1.0);
+ }
+ """,
+ _LOG10_X: """
+ const float oneOverLog10 = 0.43429448190325176;
+
+ vec4 transformXY(float x, float y) {
+ return vec4(oneOverLog10 * log(x), y, 0.0, 1.0);
+ }
+ """,
+ _LOG10_Y: """
+ const float oneOverLog10 = 0.43429448190325176;
+
+ vec4 transformXY(float x, float y) {
+ return vec4(x, oneOverLog10 * log(y), 0.0, 1.0);
+ }
+ """,
+ _LOG10_X_Y: """
+ const float oneOverLog10 = 0.43429448190325176;
+
+ vec4 transformXY(float x, float y) {
+ return vec4(oneOverLog10 * log(x),
+ oneOverLog10 * log(y),
+ 0.0, 1.0);
+ }
+ """
+ },
+ 'vertex': """
+ #version 120
+
+ uniform mat4 matrix;
+ uniform int transform;
+ uniform float size;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+
+ varying vec4 vColor;
+
+ %s
+
+ void main(void) {
+ gl_Position = matrix * transformXY(xPos, yPos);
+ vColor = color;
+ gl_PointSize = size;
+ }
+ """,
+
+ 'fragmentSymbols': {
+ DIAMOND: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 centerCoord = abs(coord - vec2(0.5, 0.5));
+ float f = centerCoord.x + centerCoord.y;
+ return clamp(size * (0.5 - f), 0.0, 1.0);
+ }
+ """,
+ CIRCLE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float radius = 0.5;
+ float r = distance(coord, vec2(0.5, 0.5));
+ return clamp(size * (radius - r), 0.0, 1.0);
+ }
+ """,
+ SQUARE: """
+ float alphaSymbol(vec2 coord, float size) {
+ return 1.0;
+ }
+ """,
+ PLUS: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 d = abs(size * (coord - vec2(0.5, 0.5)));
+ if (min(d.x, d.y) < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ X_MARKER: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 pos = floor(size * coord) + 0.5;
+ vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size));
+ if (min(d_x.x, d_x.y) <= 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ ASTERISK: """
+ float alphaSymbol(vec2 coord, float size) {
+ /* Combining +, x and cirle */
+ vec2 d_plus = abs(size * (coord - vec2(0.5, 0.5)));
+ vec2 pos = floor(size * coord) + 0.5;
+ vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size));
+ if (min(d_plus.x, d_plus.y) < 0.5) {
+ return 1.0;
+ } else if (min(d_x.x, d_x.y) <= 0.5) {
+ float r = distance(coord, vec2(0.5, 0.5));
+ return clamp(size * (0.5 - r), 0.0, 1.0);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ H_LINE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float dy = abs(size * (coord.y - 0.5));
+ if (dy < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ V_LINE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float dx = abs(size * (coord.x - 0.5));
+ if (dx < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """
+ },
+
+ 'fragment': """
+ #version 120
+
+ uniform float size;
+
+ varying vec4 vColor;
+
+ %s
+
+ void main(void) {
+ float alpha = alphaSymbol(gl_PointCoord, size);
+ if (alpha <= 0.0) {
+ discard;
+ } else {
+ gl_FragColor = vec4(vColor.rgb, alpha * clamp(vColor.a, 0.0, 1.0));
+ }
+ }
+ """
+ }
+
+ _programs = {}
+
+ def __init__(self, xVboData=None, yVboData=None, colorVboData=None,
+ marker=SQUARE, color=(0., 0., 0., 1.), size=7):
+ self.color = color
+ self._marker = None
+ self.marker = marker
+ self._size = 1
+ self.size = size
+
+ self.xVboData = xVboData
+ self.yVboData = yVboData
+ self.colorVboData = colorVboData
+ self.useColorVboData = colorVboData is not None
+
+ @property
+ def marker(self):
+ return self._marker
+
+ @marker.setter
+ def marker(self, marker):
+ if marker in _MPL_NONES:
+ self._marker = None
+ self.render = self._renderNone
+ else:
+ assert marker in self.MARKERS
+ self._marker = marker
+ self.render = self._renderMarkers
+
+ @property
+ def size(self):
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ # try:
+ # sizeRange = self._sizeRange
+ # except AttributeError:
+ # sizeRange = gl.glGetFloatv(gl.GL_POINT_SIZE_RANGE)
+ # # Shared among contexts, this should be enough..
+ # _Points2D._sizeRange = sizeRange
+ # assert size >= sizeRange[0] and size <= sizeRange[1]
+ self._size = size
+
+ @classmethod
+ def _getProgram(cls, transform, marker):
+ """On-demand shader program creation."""
+ if marker == PIXEL:
+ marker = SQUARE
+ elif marker == POINT:
+ marker = CIRCLE
+ try:
+ prgm = cls._programs[(transform, marker)]
+ except KeyError:
+ vertShdr = cls._SHADERS['vertex'] % \
+ cls._SHADERS['vertexTransforms'][transform]
+ fragShdr = cls._SHADERS['fragment'] % \
+ cls._SHADERS['fragmentSymbols'][marker]
+ prgm = Program(vertShdr, fragShdr, attrib0='xPos')
+
+ cls._programs[(transform, marker)] = prgm
+ return prgm
+
+ @classmethod
+ def init(cls):
+ version = gl.glGetString(gl.GL_VERSION)
+ majorVersion = int(version[0])
+ assert majorVersion >= 2
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ if majorVersion >= 3: # OpenGL 3
+ gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ def _renderNone(self, matrix, isXLog, isYLog):
+ pass
+
+ render = _renderNone
+
+ def _renderMarkers(self, matrix, isXLog, isYLog):
+ if isXLog:
+ transform = self._LOG10_X_Y if isYLog else self._LOG10_X
+ else:
+ transform = self._LOG10_Y if isYLog else self._LINEAR
+
+ prog = self._getProgram(transform, self.marker)
+ prog.use()
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix)
+ if self.marker == PIXEL:
+ size = 1
+ elif self.marker == POINT:
+ size = math.ceil(0.5 * self.size) + 1 # Mimic Matplotlib point
+ else:
+ size = self.size
+ gl.glUniform1f(prog.uniforms['size'], size)
+ # gl.glPointSize(self.size)
+
+ cAttrib = prog.attributes['color']
+ if self.useColorVboData and self.colorVboData is not None:
+ gl.glEnableVertexAttribArray(cAttrib)
+ self.colorVboData.setVertexAttrib(cAttrib)
+ else:
+ gl.glDisableVertexAttribArray(cAttrib)
+ gl.glVertexAttrib4f(cAttrib, *self.color)
+
+ xAttrib = prog.attributes['xPos']
+ gl.glEnableVertexAttribArray(xAttrib)
+ self.xVboData.setVertexAttrib(xAttrib)
+
+ yAttrib = prog.attributes['yPos']
+ gl.glEnableVertexAttribArray(yAttrib)
+ self.yVboData.setVertexAttrib(yAttrib)
+
+ gl.glDrawArrays(gl.GL_POINTS, 0, self.xVboData.size)
+
+ gl.glUseProgram(0)
+
+
+# error bars ##################################################################
+
+class _ErrorBars(object):
+ """Display errors bars.
+
+ This is using its own VBO as opposed to fill/points/lines.
+ There is no picking on error bars.
+ As is, there is no way to update data and errors, but it handles
+ log scales by removing data <= 0 and clipping error bars to positive
+ range.
+
+ It uses 2 vertices per error bars and uses :class:`_Lines2D` to
+ render error bars and :class:`_Points2D` to render the ends.
+ """
+
+ def __init__(self, xData, yData, xError, yError,
+ xMin, yMin,
+ color=(0., 0., 0., 1.)):
+ """Initialization.
+
+ :param numpy.ndarray xData: X coordinates of the data.
+ :param numpy.ndarray yData: Y coordinates of the data.
+ :param xError: The absolute error on the X axis.
+ :type xError: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for negative errors,
+ row 1 for positive errors.
+ :param yError: The absolute error on the Y axis.
+ :type yError: A float, or a numpy.ndarray of float32. See xError.
+ :param float xMin: The min X value already computed by GLPlotCurve2D.
+ :param float yMin: The min Y value already computed by GLPlotCurve2D.
+ :param color: The color to use for both lines and ending points.
+ :type color: tuple of 4 floats
+ """
+ self._attribs = None
+ self._isXLog, self._isYLog = False, False
+ self._xMin, self._yMin = xMin, yMin
+
+ if xError is not None or yError is not None:
+ assert len(xData) == len(yData)
+ self._xData = numpy.array(
+ xData, order='C', dtype=numpy.float32, copy=False)
+ self._yData = numpy.array(
+ yData, order='C', dtype=numpy.float32, copy=False)
+
+ # This also works if xError, yError is a float/int
+ self._xError = numpy.array(
+ xError, order='C', dtype=numpy.float32, copy=False)
+ self._yError = numpy.array(
+ yError, order='C', dtype=numpy.float32, copy=False)
+ else:
+ self._xData, self._yData = None, None
+ self._xError, self._yError = None, None
+
+ self._lines = _Lines2D(None, None, color=color, drawMode=gl.GL_LINES)
+ self._xErrPoints = _Points2D(None, None, color=color, marker=V_LINE)
+ self._yErrPoints = _Points2D(None, None, color=color, marker=H_LINE)
+
+ def _positiveValueFilter(self, onlyXPos, onlyYPos):
+ """Filter data (x, y) and errors (xError, yError) to remove
+ negative and null data values on required axis (onlyXPos, onlyYPos).
+
+ Returned arrays might be NOT contiguous.
+
+ :return: Filtered xData, yData, xError and yError arrays.
+ """
+ if ((not onlyXPos or self._xMin > 0.) and
+ (not onlyYPos or self._yMin > 0.)):
+ # No need to filter, all values are > 0 on log axes
+ return self._xData, self._yData, self._xError, self._yError
+
+ _logger.warning(
+ 'Removing values <= 0 of curve with error bars on a log axis.')
+
+ x, y = self._xData, self._yData
+ xError, yError = self._xError, self._yError
+
+ # First remove negative data
+ if onlyXPos and onlyYPos:
+ mask = (x > 0.) & (y > 0.)
+ elif onlyXPos:
+ mask = x > 0.
+ else: # onlyYPos
+ mask = y > 0.
+ x, y = x[mask], y[mask]
+
+ # Remove corresponding values from error arrays
+ if xError is not None and xError.size != 1:
+ if len(xError.shape) == 1:
+ xError = xError[mask]
+ else: # 2 rows
+ xError = xError[:, mask]
+ if yError is not None and yError.size != 1:
+ if len(yError.shape) == 1:
+ yError = yError[mask]
+ else: # 2 rows
+ yError = yError[:, mask]
+
+ return x, y, xError, yError
+
+ def _buildVertices(self, isXLog, isYLog):
+ """Generates error bars vertices according to log scales."""
+ xData, yData, xError, yError = self._positiveValueFilter(
+ isXLog, isYLog)
+
+ nbLinesPerDataPts = 1 if xError is not None else 0
+ nbLinesPerDataPts += 1 if yError is not None else 0
+
+ nbDataPts = len(xData)
+
+ # interleave coord+error, coord-error.
+ # xError vertices first if any, then yError vertices if any.
+ xCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2,
+ dtype=numpy.float32)
+ yCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2,
+ dtype=numpy.float32)
+
+ if xError is not None: # errors on the X axis
+ if len(xError.shape) == 2:
+ xErrorMinus, xErrorPlus = xError[0], xError[1]
+ else:
+ # numpy arrays of len 1 or len(xData)
+ xErrorMinus, xErrorPlus = xError, xError
+
+ # Interleave vertices for xError
+ endXError = 2 * nbDataPts
+ xCoords[0:endXError-1:2] = xData + xErrorPlus
+
+ minValues = xData - xErrorMinus
+ if isXLog:
+ # Clip min bounds to positive value
+ minValues[minValues <= 0] = FLOAT32_MINPOS
+ xCoords[1:endXError:2] = minValues
+
+ yCoords[0:endXError-1:2] = yData
+ yCoords[1:endXError:2] = yData
+ else:
+ endXError = 0
+
+ if yError is not None: # errors on the Y axis
+ if len(yError.shape) == 2:
+ yErrorMinus, yErrorPlus = yError[0], yError[1]
+ else:
+ # numpy arrays of len 1 or len(yData)
+ yErrorMinus, yErrorPlus = yError, yError
+
+ # Interleave vertices for yError
+ xCoords[endXError::2] = xData
+ xCoords[endXError+1::2] = xData
+ yCoords[endXError::2] = yData + yErrorPlus
+ minValues = yData - yErrorMinus
+ if isYLog:
+ # Clip min bounds to positive value
+ minValues[minValues <= 0] = FLOAT32_MINPOS
+ yCoords[endXError+1::2] = minValues
+
+ return xCoords, yCoords
+
+ def prepare(self, isXLog, isYLog):
+ if self._xData is None:
+ return
+
+ if self._isXLog != isXLog or self._isYLog != isYLog:
+ # Log state has changed
+ self._isXLog, self._isYLog = isXLog, isYLog
+
+ self.discard() # discard existing VBOs
+
+ if self._attribs is None:
+ xCoords, yCoords = self._buildVertices(isXLog, isYLog)
+
+ xAttrib, yAttrib = vertexBuffer((xCoords, yCoords))
+ self._attribs = xAttrib, yAttrib
+
+ self._lines.xVboData, self._lines.yVboData = xAttrib, yAttrib
+
+ # Set xError points using the same VBO as lines
+ self._xErrPoints.xVboData = xAttrib.copy()
+ self._xErrPoints.xVboData.size //= 2
+ self._xErrPoints.yVboData = yAttrib.copy()
+ self._xErrPoints.yVboData.size //= 2
+
+ # Set yError points using the same VBO as lines
+ self._yErrPoints.xVboData = xAttrib.copy()
+ self._yErrPoints.xVboData.size //= 2
+ self._yErrPoints.xVboData.offset += (xAttrib.itemsize *
+ xAttrib.size // 2)
+ self._yErrPoints.yVboData = yAttrib.copy()
+ self._yErrPoints.yVboData.size //= 2
+ self._yErrPoints.yVboData.offset += (yAttrib.itemsize *
+ yAttrib.size // 2)
+
+ def render(self, matrix, isXLog, isYLog):
+ if self._attribs is not None:
+ self._lines.render(matrix, isXLog, isYLog)
+ self._xErrPoints.render(matrix, isXLog, isYLog)
+ self._yErrPoints.render(matrix, isXLog, isYLog)
+
+ def discard(self):
+ if self._attribs is not None:
+ self._lines.xVboData, self._lines.yVboData = None, None
+ self._xErrPoints.xVboData, self._xErrPoints.yVboData = None, None
+ self._yErrPoints.xVboData, self._yErrPoints.yVboData = None, None
+ self._attribs[0].vbo.discard()
+ self._attribs = None
+
+
+# curves ######################################################################
+
+def _proxyProperty(*componentsAttributes):
+ """Create a property to access an attribute of attribute(s).
+ Useful for composition.
+ Supports multiple components this way:
+ getter returns the first found, setter sets all
+ """
+ def getter(self):
+ for compName, attrName in componentsAttributes:
+ try:
+ component = getattr(self, compName)
+ except AttributeError:
+ pass
+ else:
+ return getattr(component, attrName)
+
+ def setter(self, value):
+ for compName, attrName in componentsAttributes:
+ component = getattr(self, compName)
+ setattr(component, attrName, value)
+ return property(getter, setter)
+
+
+class GLPlotCurve2D(object):
+ def __init__(self, xData, yData, colorData=None,
+ xError=None, yError=None,
+ lineStyle=None, lineColor=None,
+ lineWidth=None, lineDashPeriod=None,
+ marker=None, markerColor=None, markerSize=None,
+ fillColor=None):
+ self._isXLog = False
+ self._isYLog = False
+ self.xData, self.yData, self.colorData = xData, yData, colorData
+
+ if fillColor is not None:
+ self.fill = _Fill2D(color=fillColor)
+ else:
+ self.fill = None
+
+ # Compute x bounds
+ if xError is None:
+ result = min_max(xData, min_positive=True)
+ self.xMin = result.minimum
+ self.xMinPos = result.min_positive
+ self.xMax = result.maximum
+ else:
+ # Takes the error into account
+ if hasattr(xError, 'shape') and len(xError.shape) == 2:
+ xErrorPlus, xErrorMinus = xError[0], xError[1]
+ else:
+ xErrorPlus, xErrorMinus = xError, xError
+ result = min_max(xData - xErrorMinus, min_positive=True)
+ self.xMin = result.minimum
+ self.xMinPos = result.min_positive
+ self.xMax = (xData + xErrorPlus).max()
+
+ # Compute y bounds
+ if yError is None:
+ result = min_max(yData, min_positive=True)
+ self.yMin = result.minimum
+ self.yMinPos = result.min_positive
+ self.yMax = result.maximum
+ else:
+ # Takes the error into account
+ if hasattr(yError, 'shape') and len(yError.shape) == 2:
+ yErrorPlus, yErrorMinus = yError[0], yError[1]
+ else:
+ yErrorPlus, yErrorMinus = yError, yError
+ result = min_max(yData - yErrorMinus, min_positive=True)
+ self.yMin = result.minimum
+ self.yMinPos = result.min_positive
+ self.yMax = (yData + yErrorPlus).max()
+
+ self._errorBars = _ErrorBars(xData, yData, xError, yError,
+ self.xMin, self.yMin)
+
+ kwargs = {'style': lineStyle}
+ if lineColor is not None:
+ kwargs['color'] = lineColor
+ if lineWidth is not None:
+ kwargs['width'] = lineWidth
+ if lineDashPeriod is not None:
+ kwargs['dashPeriod'] = lineDashPeriod
+ self.lines = _Lines2D(**kwargs)
+
+ kwargs = {'marker': marker}
+ if markerColor is not None:
+ kwargs['color'] = markerColor
+ if markerSize is not None:
+ kwargs['size'] = markerSize
+ self.points = _Points2D(**kwargs)
+
+ xVboData = _proxyProperty(('lines', 'xVboData'), ('points', 'xVboData'))
+
+ yVboData = _proxyProperty(('lines', 'yVboData'), ('points', 'yVboData'))
+
+ colorVboData = _proxyProperty(('lines', 'colorVboData'),
+ ('points', 'colorVboData'))
+
+ useColorVboData = _proxyProperty(('lines', 'useColorVboData'),
+ ('points', 'useColorVboData'))
+
+ distVboData = _proxyProperty(('lines', 'distVboData'))
+
+ lineStyle = _proxyProperty(('lines', 'style'))
+
+ lineColor = _proxyProperty(('lines', 'color'))
+
+ lineWidth = _proxyProperty(('lines', 'width'))
+
+ lineDashPeriod = _proxyProperty(('lines', 'dashPeriod'))
+
+ marker = _proxyProperty(('points', 'marker'))
+
+ markerColor = _proxyProperty(('points', 'color'))
+
+ markerSize = _proxyProperty(('points', 'size'))
+
+ @classmethod
+ def init(cls):
+ _Lines2D.init()
+ _Points2D.init()
+
+ @staticmethod
+ def _logFilterData(x, y, color=None, xLog=False, yLog=False):
+ # Copied from Plot.py
+ if xLog and yLog:
+ idx = numpy.nonzero((x > 0) & (y > 0))[0]
+ x = numpy.take(x, idx)
+ y = numpy.take(y, idx)
+ elif yLog:
+ idx = numpy.nonzero(y > 0)[0]
+ x = numpy.take(x, idx)
+ y = numpy.take(y, idx)
+ elif xLog:
+ idx = numpy.nonzero(x > 0)[0]
+ x = numpy.take(x, idx)
+ y = numpy.take(y, idx)
+ else:
+ idx = None
+
+ if idx is not None and isinstance(color, numpy.ndarray):
+ colors = numpy.zeros((x.size, 4), color.dtype)
+ colors[:, 0] = color[idx, 0]
+ colors[:, 1] = color[idx, 1]
+ colors[:, 2] = color[idx, 2]
+ colors[:, 3] = color[idx, 3]
+ else:
+ colors = color
+ return x, y, colors
+
+ def prepare(self, isXLog, isYLog):
+ # init only supports updating isXLog, isYLog
+ xData, yData, colorData = self.xData, self.yData, self.colorData
+
+ if self._isXLog != isXLog or self._isYLog != isYLog:
+ # Log state has changed
+ self._isXLog, self._isYLog = isXLog, isYLog
+
+ # Check if data <= 0. with log scale
+ if (isXLog and self.xMin <= 0.) or (isYLog and self.yMin <= 0.):
+ # Filtering data is needed
+ xData, yData, colorData = self._logFilterData(
+ self.xData, self.yData, self.colorData,
+ self._isXLog, self._isYLog)
+
+ self.discard() # discard existing VBOs
+
+ if self.xVboData is None:
+ xAttrib, yAttrib, cAttrib, dAttrib = None, None, None, None
+ if self.lineStyle in (DASHED, DASHDOT, DOTTED):
+ dists = _distancesFromArrays(xData, yData)
+ if self.colorData is None:
+ xAttrib, yAttrib, dAttrib = vertexBuffer(
+ (xData, yData, dists),
+ prefix=(1, 1, 0), suffix=(1, 1, 0))
+ else:
+ xAttrib, yAttrib, cAttrib, dAttrib = vertexBuffer(
+ (xData, yData, colorData, dists),
+ prefix=(1, 1, 0, 0), suffix=(1, 1, 0, 0))
+ elif self.colorData is None:
+ xAttrib, yAttrib = vertexBuffer(
+ (xData, yData), prefix=(1, 1), suffix=(1, 1))
+ else:
+ xAttrib, yAttrib, cAttrib = vertexBuffer(
+ (xData, yData, colorData), prefix=(1, 1, 0))
+
+ # Shrink VBO
+ self.xVboData = xAttrib.copy()
+ self.xVboData.size -= 2
+ self.xVboData.offset += xAttrib.itemsize
+
+ self.yVboData = yAttrib.copy()
+ self.yVboData.size -= 2
+ self.yVboData.offset += yAttrib.itemsize
+
+ if cAttrib is not None and colorData.dtype.kind == 'u':
+ cAttrib.normalisation = True # Normalise uint to [0, 1]
+ self.colorVboData = cAttrib
+ self.useColorVboData = cAttrib is not None
+ self.distVboData = dAttrib
+
+ if self.fill is not None:
+ xData = xData.reshape(xData.size, 1)
+ zero = numpy.array((1e-32,), dtype=self.yData.dtype)
+
+ # Add one point before data: (x0, 0.)
+ xAttrib.vbo.update(xData[0], xAttrib.offset,
+ xData[0].itemsize)
+ yAttrib.vbo.update(zero, yAttrib.offset, zero.itemsize)
+
+ # Add one point after data: (xN, 0.)
+ xAttrib.vbo.update(xData[-1],
+ xAttrib.offset +
+ (xAttrib.size - 1) * xAttrib.itemsize,
+ xData[-1].itemsize)
+ yAttrib.vbo.update(zero,
+ yAttrib.offset +
+ (yAttrib.size - 1) * yAttrib.itemsize,
+ zero.itemsize)
+
+ self.fill.xFillVboData = xAttrib
+ self.fill.yFillVboData = yAttrib
+ self.fill.xMin, self.fill.yMin = self.xMin, self.yMin
+ self.fill.xMax, self.fill.yMax = self.xMax, self.yMax
+
+ self._errorBars.prepare(isXLog, isYLog)
+
+ def render(self, matrix, isXLog, isYLog):
+ self.prepare(isXLog, isYLog)
+ if self.fill is not None:
+ self.fill.render(matrix, isXLog, isYLog)
+ self._errorBars.render(matrix, isXLog, isYLog)
+ self.lines.render(matrix, isXLog, isYLog)
+ self.points.render(matrix, isXLog, isYLog)
+
+ def discard(self):
+ if self.xVboData is not None:
+ self.xVboData.vbo.discard()
+
+ self.xVboData = None
+ self.yVboData = None
+ self.colorVboData = None
+ self.distVboData = None
+
+ self._errorBars.discard()
+
+ def pick(self, xPickMin, yPickMin, xPickMax, yPickMax):
+ """Perform picking on the curve according to its rendering.
+
+ The picking area is [xPickMin, xPickMax], [yPickMin, yPickMax].
+
+ In case a segment between 2 points with indices i, i+1 is picked,
+ only its lower index end point (i.e., i) is added to the result.
+ In case an end point with index i is picked it is added to the result,
+ and the segment [i-1, i] is not tested for picking.
+
+ :return: The indices of the picked data
+ :rtype: list of int
+ """
+ if (self.marker is None and self.lineStyle is None) or \
+ self.xMin > xPickMax or xPickMin > self.xMax or \
+ self.yMin > yPickMax or yPickMin > self.yMax:
+ # Note: With log scale the bounding box is too large if
+ # some data <= 0.
+ return None
+
+ elif self.lineStyle is not None:
+ # Using Cohen-Sutherland algorithm for line clipping
+ codes = ((self.yData > yPickMax) << 3) | \
+ ((self.yData < yPickMin) << 2) | \
+ ((self.xData > xPickMax) << 1) | \
+ (self.xData < xPickMin)
+
+ # Add all points that are inside the picking area
+ indices = numpy.nonzero(codes == 0)[0].tolist()
+
+ # Segment that might cross the area with no end point inside it
+ segToTestIdx = numpy.nonzero((codes[:-1] != 0) &
+ (codes[1:] != 0) &
+ ((codes[:-1] & codes[1:]) == 0))[0]
+
+ TOP, BOTTOM, RIGHT, LEFT = (1 << 3), (1 << 2), (1 << 1), (1 << 0)
+
+ for index in segToTestIdx:
+ if index not in indices:
+ x0, y0 = self.xData[index], self.yData[index]
+ x1, y1 = self.xData[index + 1], self.yData[index + 1]
+ code1 = codes[index + 1]
+
+ # check for crossing with horizontal bounds
+ # y0 == y1 is a never event:
+ # => pt0 and pt1 in same vertical area are not in segToTest
+ if code1 & TOP:
+ x = x0 + (x1 - x0) * (yPickMax - y0) / (y1 - y0)
+ elif code1 & BOTTOM:
+ x = x0 + (x1 - x0) * (yPickMin - y0) / (y1 - y0)
+ else:
+ x = None # No horizontal bounds intersection test
+
+ if x is not None and xPickMin <= x <= xPickMax:
+ # Intersection
+ indices.append(index)
+
+ else:
+ # check for crossing with vertical bounds
+ # x0 == x1 is a never event (see remark for y)
+ if code1 & RIGHT:
+ y = y0 + (y1 - y0) * (xPickMax - x0) / (x1 - x0)
+ elif code1 & LEFT:
+ y = y0 + (y1 - y0) * (xPickMin - x0) / (x1 - x0)
+ else:
+ y = None # No vertical bounds intersection test
+
+ if y is not None and yPickMin <= y <= yPickMax:
+ # Intersection
+ indices.append(index)
+
+ indices.sort()
+
+ else:
+ indices = numpy.nonzero((self.xData >= xPickMin) &
+ (self.xData <= xPickMax) &
+ (self.yData >= yPickMin) &
+ (self.yData <= yPickMax))[0].tolist()
+
+ return indices
diff --git a/silx/gui/plot/backends/glutils/GLPlotFrame.py b/silx/gui/plot/backends/glutils/GLPlotFrame.py
new file mode 100644
index 0000000..367419c
--- /dev/null
+++ b/silx/gui/plot/backends/glutils/GLPlotFrame.py
@@ -0,0 +1,1039 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This modules provides the rendering of plot titles, axes and grid.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+# TODO
+# keep aspect ratio managed here?
+# smarter dirty flag handling?
+
+import math
+import weakref
+import logging
+from collections import namedtuple
+
+import numpy
+
+from ...._glutils import gl, Program
+from ..._utils import FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX
+from .GLSupport import mat4Ortho
+from .GLText import Text2D, CENTER, BOTTOM, TOP, LEFT, RIGHT, ROTATE_270
+from ..._utils.ticklayout import niceNumbersAdaptative, niceNumbersForLog10
+
+
+_logger = logging.getLogger(__name__)
+
+
+# PlotAxis ####################################################################
+
+class PlotAxis(object):
+ """Represents a 1D axis of the plot.
+ This class is intended to be used with :class:`GLPlotFrame`.
+ """
+
+ def __init__(self, plot,
+ tickLength=(0., 0.),
+ labelAlign=CENTER, labelVAlign=CENTER,
+ titleAlign=CENTER, titleVAlign=CENTER,
+ titleRotate=0, titleOffset=(0., 0.)):
+ self._ticks = None
+
+ self._plot = weakref.ref(plot)
+
+ self._isLog = False
+ self._dataRange = 1., 100.
+ self._displayCoords = (0., 0.), (1., 0.)
+ self._title = ''
+
+ self._tickLength = tickLength
+ self._labelAlign = labelAlign
+ self._labelVAlign = labelVAlign
+ self._titleAlign = titleAlign
+ self._titleVAlign = titleVAlign
+ self._titleRotate = titleRotate
+ self._titleOffset = titleOffset
+
+ @property
+ def dataRange(self):
+ """The range of the data represented on the axis as a tuple
+ of 2 floats: (min, max)."""
+ return self._dataRange
+
+ @dataRange.setter
+ def dataRange(self, dataRange):
+ assert len(dataRange) == 2
+ assert dataRange[0] <= dataRange[1]
+ dataRange = float(dataRange[0]), float(dataRange[1])
+
+ if dataRange != self._dataRange:
+ self._dataRange = dataRange
+ self._dirtyTicks()
+
+ @property
+ def isLog(self):
+ """Whether the axis is using a log10 scale or not as a bool."""
+ return self._isLog
+
+ @isLog.setter
+ def isLog(self, isLog):
+ isLog = bool(isLog)
+ if isLog != self._isLog:
+ self._isLog = isLog
+ self._dirtyTicks()
+
+ @property
+ def displayCoords(self):
+ """The coordinates of the start and end points of the axis
+ in display space (i.e., in pixels) as a tuple of 2 tuples of
+ 2 floats: ((x0, y0), (x1, y1)).
+ """
+ return self._displayCoords
+
+ @displayCoords.setter
+ def displayCoords(self, displayCoords):
+ assert len(displayCoords) == 2
+ assert len(displayCoords[0]) == 2
+ assert len(displayCoords[1]) == 2
+ displayCoords = tuple(displayCoords[0]), tuple(displayCoords[1])
+ if displayCoords != self._displayCoords:
+ self._displayCoords = displayCoords
+ self._dirtyTicks()
+
+ @property
+ def title(self):
+ """The text label associated with this axis as a str in latin-1."""
+ return self._title
+
+ @title.setter
+ def title(self, title):
+ if title != self._title:
+ self._title = title
+
+ plot = self._plot()
+ if plot is not None:
+ plot._dirty()
+
+ @property
+ def ticks(self):
+ """Ticks as tuples: ((x, y) in display, dataPos, textLabel)."""
+ if self._ticks is None:
+ self._ticks = tuple(self._ticksGenerator())
+ return self._ticks
+
+ def getVerticesAndLabels(self):
+ """Create the list of vertices for axis and associated text labels.
+
+ :returns: A tuple: List of 2D line vertices, List of Text2D labels.
+ """
+ vertices = list(self.displayCoords) # Add start and end points
+ labels = []
+ tickLabelsSize = [0., 0.]
+
+ xTickLength, yTickLength = self._tickLength
+ for (xPixel, yPixel), dataPos, text in self.ticks:
+ if text is None:
+ tickScale = 0.5
+ else:
+ tickScale = 1.
+
+ label = Text2D(text=text,
+ x=xPixel - xTickLength,
+ y=yPixel - yTickLength,
+ align=self._labelAlign,
+ valign=self._labelVAlign)
+
+ width, height = label.size
+ if width > tickLabelsSize[0]:
+ tickLabelsSize[0] = width
+ if height > tickLabelsSize[1]:
+ tickLabelsSize[1] = height
+
+ labels.append(label)
+
+ vertices.append((xPixel, yPixel))
+ vertices.append((xPixel + tickScale * xTickLength,
+ yPixel + tickScale * yTickLength))
+
+ (x0, y0), (x1, y1) = self.displayCoords
+ xAxisCenter = 0.5 * (x0 + x1)
+ yAxisCenter = 0.5 * (y0 + y1)
+
+ xOffset, yOffset = self._titleOffset
+
+ # Adaptative title positioning:
+ # tickNorm = math.sqrt(xTickLength ** 2 + yTickLength ** 2)
+ # xOffset = -tickLabelsSize[0] * xTickLength / tickNorm
+ # xOffset -= 3 * xTickLength
+ # yOffset = -tickLabelsSize[1] * yTickLength / tickNorm
+ # yOffset -= 3 * yTickLength
+
+ axisTitle = Text2D(text=self.title,
+ x=xAxisCenter + xOffset,
+ y=yAxisCenter + yOffset,
+ align=self._titleAlign,
+ valign=self._titleVAlign,
+ rotate=self._titleRotate)
+ labels.append(axisTitle)
+
+ return vertices, labels
+
+ def _dirtyTicks(self):
+ """Mark ticks as dirty and notify listener (i.e., background)."""
+ self._ticks = None
+ plot = self._plot()
+ if plot is not None:
+ plot._dirty()
+
+ @staticmethod
+ def _frange(start, stop, step):
+ """range for float (including stop)."""
+ while start <= stop:
+ yield start
+ start += step
+
+ def _ticksGenerator(self):
+ """Generator of ticks as tuples:
+ ((x, y) in display, dataPos, textLabel).
+ """
+ dataMin, dataMax = self.dataRange
+ if self.isLog and dataMin <= 0.:
+ _logger.warning(
+ 'Getting ticks while isLog=True and dataRange[0]<=0.')
+ dataMin = 1.
+ if dataMax < dataMin:
+ dataMax = 1.
+
+ if dataMin != dataMax: # data range is not null
+ (x0, y0), (x1, y1) = self.displayCoords
+
+ if self.isLog:
+ logMin, logMax = math.log10(dataMin), math.log10(dataMax)
+ tickMin, tickMax, step, _ = niceNumbersForLog10(logMin, logMax)
+
+ xScale = (x1 - x0) / (logMax - logMin)
+ yScale = (y1 - y0) / (logMax - logMin)
+
+ for logPos in self._frange(tickMin, tickMax, step):
+ if logMin <= logPos <= logMax:
+ dataPos = 10 ** logPos
+ xPixel = x0 + (logPos - logMin) * xScale
+ yPixel = y0 + (logPos - logMin) * yScale
+ text = '1e%+03d' % logPos
+ yield ((xPixel, yPixel), dataPos, text)
+
+ if step == 1:
+ ticks = list(self._frange(tickMin, tickMax, step))[:-1]
+ for logPos in ticks:
+ dataOrigPos = 10 ** logPos
+ for index in range(2, 10):
+ dataPos = dataOrigPos * index
+ if dataMin <= dataPos <= dataMax:
+ logSubPos = math.log10(dataPos)
+ xPixel = x0 + (logSubPos - logMin) * xScale
+ yPixel = y0 + (logSubPos - logMin) * yScale
+ yield ((xPixel, yPixel), dataPos, None)
+
+ else:
+ xScale = (x1 - x0) / (dataMax - dataMin)
+ yScale = (y1 - y0) / (dataMax - dataMin)
+
+ nbPixels = math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2))
+
+ # Density of 1.3 label per 92 pixels
+ # i.e., 1.3 label per inch on a 92 dpi screen
+ tickMin, tickMax, step, nbFrac = niceNumbersAdaptative(
+ dataMin, dataMax, nbPixels, 1.3 / 92)
+
+ for dataPos in self._frange(tickMin, tickMax, step):
+ if dataMin <= dataPos <= dataMax:
+ xPixel = x0 + (dataPos - dataMin) * xScale
+ yPixel = y0 + (dataPos - dataMin) * yScale
+
+ if nbFrac == 0:
+ text = '%g' % dataPos
+ else:
+ text = ('%.' + str(nbFrac) + 'f') % dataPos
+ yield ((xPixel, yPixel), dataPos, text)
+
+
+# GLPlotFrame #################################################################
+
+class GLPlotFrame(object):
+ """Base class for rendering a 2D frame surrounded by axes."""
+
+ _TICK_LENGTH_IN_PIXELS = 5
+ _LINE_WIDTH = 1
+
+ _SHADERS = {
+ 'vertex': """
+ attribute vec2 position;
+ uniform mat4 matrix;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ }
+ """,
+ 'fragment': """
+ uniform vec4 color;
+ uniform float tickFactor; /* = 1./tickLength or 0. for solid line */
+
+ void main(void) {
+ if (mod(tickFactor * (gl_FragCoord.x + gl_FragCoord.y), 2.) < 1.) {
+ gl_FragColor = color;
+ } else {
+ discard;
+ }
+ }
+ """
+ }
+
+ _Margins = namedtuple('Margins', ('left', 'right', 'top', 'bottom'))
+
+ def __init__(self, margins):
+ """
+ :param margins: The margins around plot area for axis and labels.
+ :type margins: dict with 'left', 'right', 'top', 'bottom' keys and
+ values as ints.
+ """
+ self._renderResources = None
+
+ self._margins = self._Margins(**margins)
+
+ self.axes = [] # List of PlotAxis to be updated by subclasses
+
+ self._grid = False
+ self._size = 0., 0.
+ self._title = ''
+
+ @property
+ def isDirty(self):
+ """True if it need to refresh graphic rendering, False otherwise."""
+ return self._renderResources is None
+
+ GRID_NONE = 0
+ GRID_MAIN_TICKS = 1
+ GRID_SUB_TICKS = 2
+ GRID_ALL_TICKS = (GRID_MAIN_TICKS + GRID_SUB_TICKS)
+
+ @property
+ def margins(self):
+ """Margins in pixels around the plot."""
+ return self._margins
+
+ @property
+ def grid(self):
+ """Grid display mode:
+ - 0: No grid.
+ - 1: Grid on main ticks.
+ - 2: Grid on sub-ticks for log scale axes.
+ - 3: Grid on main and sub ticks."""
+ return self._grid
+
+ @grid.setter
+ def grid(self, grid):
+ assert grid in (self.GRID_NONE, self.GRID_MAIN_TICKS,
+ self.GRID_SUB_TICKS, self.GRID_ALL_TICKS)
+ if grid != self._grid:
+ self._grid = grid
+ self._dirty()
+
+ @property
+ def size(self):
+ """Size in pixels of the plot area including margins."""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ assert len(size) == 2
+ size = tuple(size)
+ if size != self._size:
+ self._size = size
+ self._dirty()
+
+ @property
+ def plotOrigin(self):
+ """Plot area origin (left, top) in widget coordinates in pixels."""
+ return self.margins.left, self.margins.top
+
+ @property
+ def plotSize(self):
+ """Plot area size (width, height) in pixels."""
+ w, h = self.size
+ w -= self.margins.left + self.margins.right
+ h -= self.margins.top + self.margins.bottom
+ return w, h
+
+ @property
+ def title(self):
+ """Main title as a str in latin-1."""
+ return self._title
+
+ @title.setter
+ def title(self, title):
+ if title != self._title:
+ self._title = title
+ self._dirty()
+
+ # In-place update
+ # if self._renderResources is not None:
+ # self._renderResources[-1][-1].text = title
+
+ def _dirty(self):
+ # When Text2D require discard we need to handle it
+ self._renderResources = None
+
+ def _buildGridVertices(self):
+ if self._grid == self.GRID_NONE:
+ return []
+
+ elif self._grid == self.GRID_MAIN_TICKS:
+ def test(text):
+ return text is not None
+ elif self._grid == self.GRID_SUB_TICKS:
+ def test(text):
+ return text is None
+ elif self._grid == self.GRID_ALL_TICKS:
+ def test(_):
+ return True
+ else:
+ logging.warning('Wrong grid mode: %d' % self._grid)
+ return []
+
+ return self._buildGridVerticesWithTest(test)
+
+ def _buildGridVerticesWithTest(self, test):
+ """Override in subclass to generate grid vertices"""
+ return []
+
+ def _buildVerticesAndLabels(self):
+ # To fill with copy of axes lists
+ vertices = []
+ labels = []
+
+ for axis in self.axes:
+ axisVertices, axisLabels = axis.getVerticesAndLabels()
+ vertices += axisVertices
+ labels += axisLabels
+
+ vertices = numpy.array(vertices, dtype=numpy.float32)
+
+ # Add main title
+ xTitle = (self.size[0] + self.margins.left -
+ self.margins.right) // 2
+ yTitle = self.margins.top - self._TICK_LENGTH_IN_PIXELS
+ labels.append(Text2D(text=self.title,
+ x=xTitle,
+ y=yTitle,
+ align=CENTER,
+ valign=BOTTOM))
+
+ # grid
+ gridVertices = numpy.array(self._buildGridVertices(),
+ dtype=numpy.float32)
+
+ self._renderResources = (vertices, gridVertices, labels)
+
+ _program = Program(
+ _SHADERS['vertex'], _SHADERS['fragment'], attrib0='position')
+
+ def render(self):
+ if self._renderResources is None:
+ self._buildVerticesAndLabels()
+ vertices, gridVertices, labels = self._renderResources
+
+ width, height = self.size
+ matProj = mat4Ortho(0, width, height, 0, 1, -1)
+
+ gl.glViewport(0, 0, width, height)
+
+ prog = self._program
+ prog.use()
+
+ gl.glLineWidth(self._LINE_WIDTH)
+
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matProj)
+ gl.glUniform4f(prog.uniforms['color'], 0., 0., 0., 1.)
+ gl.glUniform1f(prog.uniforms['tickFactor'], 0.)
+
+ gl.glEnableVertexAttribArray(prog.attributes['position'])
+ gl.glVertexAttribPointer(prog.attributes['position'],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, vertices)
+
+ gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
+
+ for label in labels:
+ label.render(matProj)
+
+ def renderGrid(self):
+ if self._grid == self.GRID_NONE:
+ return
+
+ if self._renderResources is None:
+ self._buildVerticesAndLabels()
+ vertices, gridVertices, labels = self._renderResources
+
+ width, height = self.size
+ matProj = mat4Ortho(0, width, height, 0, 1, -1)
+
+ gl.glViewport(0, 0, width, height)
+
+ prog = self._program
+ prog.use()
+
+ gl.glLineWidth(self._LINE_WIDTH)
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matProj)
+ gl.glUniform4f(prog.uniforms['color'], 0.7, 0.7, 0.7, 1.)
+ gl.glUniform1f(prog.uniforms['tickFactor'], 0.) # 1/2.) # 1/tickLen
+
+ gl.glEnableVertexAttribArray(prog.attributes['position'])
+ gl.glVertexAttribPointer(prog.attributes['position'],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, gridVertices)
+
+ gl.glDrawArrays(gl.GL_LINES, 0, len(gridVertices))
+
+
+# GLPlotFrame2D ###############################################################
+
+class GLPlotFrame2D(GLPlotFrame):
+ def __init__(self, margins):
+ """
+ :param margins: The margins around plot area for axis and labels.
+ :type margins: dict with 'left', 'right', 'top', 'bottom' keys and
+ values as ints.
+ """
+ super(GLPlotFrame2D, self).__init__(margins)
+ self.axes.append(PlotAxis(self,
+ tickLength=(0., -5.),
+ labelAlign=CENTER, labelVAlign=TOP,
+ titleAlign=CENTER, titleVAlign=TOP,
+ titleRotate=0,
+ titleOffset=(0, self.margins.bottom // 2)))
+
+ self._x2AxisCoords = ()
+
+ self.axes.append(PlotAxis(self,
+ tickLength=(5., 0.),
+ labelAlign=RIGHT, labelVAlign=CENTER,
+ titleAlign=CENTER, titleVAlign=BOTTOM,
+ titleRotate=ROTATE_270,
+ titleOffset=(-3 * self.margins.left // 4,
+ 0)))
+
+ self._y2Axis = PlotAxis(self,
+ tickLength=(-5., 0.),
+ labelAlign=LEFT, labelVAlign=CENTER,
+ titleAlign=CENTER, titleVAlign=TOP,
+ titleRotate=ROTATE_270,
+ titleOffset=(3 * self.margins.right // 4,
+ 0))
+
+ self._isYAxisInverted = False
+
+ self._dataRanges = {
+ 'x': (1., 100.), 'y': (1., 100.), 'y2': (1., 100.)}
+
+ self._baseVectors = (1., 0.), (0., 1.)
+
+ self._transformedDataRanges = None
+ self._transformedDataProjMat = None
+ self._transformedDataY2ProjMat = None
+
+ def _dirty(self):
+ super(GLPlotFrame2D, self)._dirty()
+ self._transformedDataRanges = None
+ self._transformedDataProjMat = None
+ self._transformedDataY2ProjMat = None
+
+ @property
+ def isDirty(self):
+ """True if it need to refresh graphic rendering, False otherwise."""
+ return (super(GLPlotFrame2D, self).isDirty or
+ self._transformedDataRanges is None or
+ self._transformedDataProjMat is None or
+ self._transformedDataY2ProjMat is None)
+
+ @property
+ def xAxis(self):
+ return self.axes[0]
+
+ @property
+ def yAxis(self):
+ return self.axes[1]
+
+ @property
+ def y2Axis(self):
+ return self._y2Axis
+
+ @property
+ def isY2Axis(self):
+ """Whether to display the left Y axis or not."""
+ return len(self.axes) == 3
+
+ @isY2Axis.setter
+ def isY2Axis(self, isY2Axis):
+ if isY2Axis != self.isY2Axis:
+ if isY2Axis:
+ self.axes.append(self._y2Axis)
+ else:
+ self.axes = self.axes[:2]
+
+ self._dirty()
+
+ @property
+ def isYAxisInverted(self):
+ """Whether Y axes are inverted or not as a bool."""
+ return self._isYAxisInverted
+
+ @isYAxisInverted.setter
+ def isYAxisInverted(self, value):
+ value = bool(value)
+ if value != self._isYAxisInverted:
+ self._isYAxisInverted = value
+ self._dirty()
+
+ DEFAULT_BASE_VECTORS = (1., 0.), (0., 1.)
+ """Values of baseVectors for orthogonal axes."""
+
+ @property
+ def baseVectors(self):
+ """Coordinates of the X and Y axes in the orthogonal plot coords.
+
+ Raises ValueError if corresponding matrix is singular.
+
+ 2 tuples of 2 floats: (xx, xy), (yx, yy)
+ """
+ return self._baseVectors
+
+ @baseVectors.setter
+ def baseVectors(self, baseVectors):
+ self._dirty()
+
+ (xx, xy), (yx, yy) = baseVectors
+ vectors = (float(xx), float(xy)), (float(yx), float(yy))
+
+ det = (vectors[0][0] * vectors[1][1] - vectors[1][0] * vectors[0][1])
+ if det == 0.:
+ raise ValueError("Singular matrix for base vectors: " +
+ str(vectors))
+
+ if vectors != self._baseVectors:
+ self._baseVectors = vectors
+ self._dirty()
+
+ @property
+ def dataRanges(self):
+ """Ranges of data visible in the plot on x, y and y2 axes.
+
+ This is different to the axes range when axes are not orthogonal.
+
+ Type: ((xMin, xMax), (yMin, yMax), (y2Min, y2Max))
+ """
+ return self._DataRanges(self._dataRanges['x'],
+ self._dataRanges['y'],
+ self._dataRanges['y2'])
+
+ @staticmethod
+ def _clipToSafeRange(min_, max_, isLog):
+ # Clip range if needed
+ minLimit = FLOAT32_MINPOS if isLog else FLOAT32_SAFE_MIN
+ min_ = numpy.clip(min_, minLimit, FLOAT32_SAFE_MAX)
+ max_ = numpy.clip(max_, minLimit, FLOAT32_SAFE_MAX)
+ assert min_ < max_
+ return min_, max_
+
+ def setDataRanges(self, x=None, y=None, y2=None):
+ """Set data range over each axes.
+
+ The provided ranges are clipped to possible values
+ (i.e., 32 float range + positive range for log scale).
+
+ :param x: (min, max) data range over X axis
+ :param y: (min, max) data range over Y axis
+ :param y2: (min, max) data range over Y2 axis
+ """
+ if x is not None:
+ self._dataRanges['x'] = \
+ self._clipToSafeRange(x[0], x[1], self.xAxis.isLog)
+
+ if y is not None:
+ self._dataRanges['y'] = \
+ self._clipToSafeRange(y[0], y[1], self.yAxis.isLog)
+
+ if y2 is not None:
+ self._dataRanges['y2'] = \
+ self._clipToSafeRange(y2[0], y2[1], self.y2Axis.isLog)
+
+ self.xAxis.dataRange = self._dataRanges['x']
+ self.yAxis.dataRange = self._dataRanges['y']
+ self.y2Axis.dataRange = self._dataRanges['y2']
+
+ _DataRanges = namedtuple('dataRanges', ('x', 'y', 'y2'))
+
+ @property
+ def transformedDataRanges(self):
+ """Bounds of the displayed area in transformed data coordinates
+ (i.e., log scale applied if any as well as skew)
+
+ 3-tuple of 2-tuple (min, max) for each axis: x, y, y2.
+ """
+ if self._transformedDataRanges is None:
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self.dataRanges
+
+ if self.xAxis.isLog:
+ try:
+ xMin = math.log10(xMin)
+ except ValueError:
+ _logger.info('xMin: warning log10(%f)', xMin)
+ xMin = 0.
+ try:
+ xMax = math.log10(xMax)
+ except ValueError:
+ _logger.info('xMax: warning log10(%f)', xMax)
+ xMax = 0.
+
+ if self.yAxis.isLog:
+ try:
+ yMin = math.log10(yMin)
+ except ValueError:
+ _logger.info('yMin: warning log10(%f)', yMin)
+ yMin = 0.
+ try:
+ yMax = math.log10(yMax)
+ except ValueError:
+ _logger.info('yMax: warning log10(%f)', yMax)
+ yMax = 0.
+
+ try:
+ y2Min = math.log10(y2Min)
+ except ValueError:
+ _logger.info('yMin: warning log10(%f)', y2Min)
+ y2Min = 0.
+ try:
+ y2Max = math.log10(y2Max)
+ except ValueError:
+ _logger.info('yMax: warning log10(%f)', y2Max)
+ y2Max = 0.
+
+ # Non-orthogonal axes
+ if self.baseVectors != self.DEFAULT_BASE_VECTORS:
+ (xx, xy), (yx, yy) = self.baseVectors
+ skew_mat = numpy.array(((xx, yx), (xy, yy)))
+
+ corners = [(xMin, yMin), (xMin, yMax),
+ (xMax, yMin), (xMax, yMax),
+ (xMin, y2Min), (xMin, y2Max),
+ (xMax, y2Min), (xMax, y2Max)]
+
+ corners = numpy.array(
+ [numpy.dot(skew_mat, corner) for corner in corners],
+ dtype=numpy.float32)
+ xMin, xMax = corners[:, 0].min(), corners[:, 0].max()
+ yMin, yMax = corners[0:4, 1].min(), corners[0:4, 1].max()
+ y2Min, y2Max = corners[4:, 1].min(), corners[4:, 1].max()
+
+ self._transformedDataRanges = self._DataRanges(
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max))
+
+ return self._transformedDataRanges
+
+ @property
+ def transformedDataProjMat(self):
+ """Orthographic projection matrix for rendering transformed data
+
+ :type: numpy.matrix
+ """
+ if self._transformedDataProjMat is None:
+ xMin, xMax = self.transformedDataRanges.x
+ yMin, yMax = self.transformedDataRanges.y
+
+ if self.isYAxisInverted:
+ mat = mat4Ortho(xMin, xMax, yMax, yMin, 1, -1)
+ else:
+ mat = mat4Ortho(xMin, xMax, yMin, yMax, 1, -1)
+
+ # Non-orthogonal axes
+ if self.baseVectors != self.DEFAULT_BASE_VECTORS:
+ (xx, xy), (yx, yy) = self.baseVectors
+ mat = mat * numpy.matrix((
+ (xx, yx, 0., 0.),
+ (xy, yy, 0., 0.),
+ (0., 0., 1., 0.),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+ self._transformedDataProjMat = mat
+
+ return self._transformedDataProjMat
+
+ @property
+ def transformedDataY2ProjMat(self):
+ """Orthographic projection matrix for rendering transformed data
+ for the 2nd Y axis
+
+ :type: numpy.matrix
+ """
+ if self._transformedDataY2ProjMat is None:
+ xMin, xMax = self.transformedDataRanges.x
+ y2Min, y2Max = self.transformedDataRanges.y2
+
+ if self.isYAxisInverted:
+ mat = mat4Ortho(xMin, xMax, y2Max, y2Min, 1, -1)
+ else:
+ mat = mat4Ortho(xMin, xMax, y2Min, y2Max, 1, -1)
+
+ # Non-orthogonal axes
+ if self.baseVectors != self.DEFAULT_BASE_VECTORS:
+ (xx, xy), (yx, yy) = self.baseVectors
+ mat = mat * numpy.matrix((
+ (xx, yx, 0., 0.),
+ (xy, yy, 0., 0.),
+ (0., 0., 1., 0.),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+ self._transformedDataY2ProjMat = mat
+
+ return self._transformedDataY2ProjMat
+
+ def dataToPixel(self, x, y, axis='left'):
+ """Convert data coordinate to widget pixel coordinate.
+ """
+ assert axis in ('left', 'right')
+
+ trBounds = self.transformedDataRanges
+
+ if self.xAxis.isLog:
+ if x < FLOAT32_MINPOS:
+ return None
+ xDataTr = math.log10(x)
+ else:
+ xDataTr = x
+
+ if self.yAxis.isLog:
+ if y < FLOAT32_MINPOS:
+ return None
+ yDataTr = math.log10(y)
+ else:
+ yDataTr = y
+
+ # Non-orthogonal axes
+ if self.baseVectors != self.DEFAULT_BASE_VECTORS:
+ (xx, xy), (yx, yy) = self.baseVectors
+ skew_mat = numpy.array(((xx, yx), (xy, yy)))
+
+ coords = numpy.dot(skew_mat, numpy.array((xDataTr, yDataTr)))
+ xDataTr, yDataTr = coords
+
+ plotWidth, plotHeight = self.plotSize
+
+ xPixel = int(self.margins.left +
+ plotWidth * (xDataTr - trBounds.x[0]) /
+ (trBounds.x[1] - trBounds.x[0]))
+
+ usedAxis = trBounds.y if axis == "left" else trBounds.y2
+ yOffset = (plotHeight * (yDataTr - usedAxis[0]) /
+ (usedAxis[1] - usedAxis[0]))
+
+ if self.isYAxisInverted:
+ yPixel = int(self.margins.top + yOffset)
+ else:
+ yPixel = int(self.size[1] - self.margins.bottom - yOffset)
+
+ return xPixel, yPixel
+
+ def pixelToData(self, x, y, axis="left"):
+ """Convert pixel position to data coordinates.
+
+ :param float x: X coord
+ :param float y: Y coord
+ :param str axis: Y axis to use in ('left', 'right')
+ :return: (x, y) position in data coords
+ """
+ assert axis in ("left", "right")
+
+ plotWidth, plotHeight = self.plotSize
+
+ trBounds = self.transformedDataRanges
+
+ xData = (x - self.margins.left + 0.5) / float(plotWidth)
+ xData = trBounds.x[0] + xData * (trBounds.x[1] - trBounds.x[0])
+
+ usedAxis = trBounds.y if axis == "left" else trBounds.y2
+ if self.isYAxisInverted:
+ yData = (y - self.margins.top + 0.5) / float(plotHeight)
+ yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0])
+ else:
+ yData = self.size[1] - self.margins.bottom - y - 0.5
+ yData /= float(plotHeight)
+ yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0])
+
+ # non-orthogonal axis
+ if self.baseVectors != self.DEFAULT_BASE_VECTORS:
+ (xx, xy), (yx, yy) = self.baseVectors
+ skew_mat = numpy.array(((xx, yx), (xy, yy)))
+ skew_mat = numpy.linalg.inv(skew_mat)
+
+ coords = numpy.dot(skew_mat, numpy.array((xData, yData)))
+ xData, yData = coords
+
+ if self.xAxis.isLog:
+ xData = pow(10, xData)
+ if self.yAxis.isLog:
+ yData = pow(10, yData)
+
+ return xData, yData
+
+ def _buildGridVerticesWithTest(self, test):
+ vertices = []
+
+ if self.baseVectors == self.DEFAULT_BASE_VECTORS:
+ for axis in self.axes:
+ for (xPixel, yPixel), data, text in axis.ticks:
+ if test(text):
+ vertices.append((xPixel, yPixel))
+ if axis == self.xAxis:
+ vertices.append((xPixel, self.margins.top))
+ elif axis == self.yAxis:
+ vertices.append((self.size[0] - self.margins.right,
+ yPixel))
+ else: # axis == self.y2Axis
+ vertices.append((self.margins.left, yPixel))
+
+ else:
+ # Get plot corners in data coords
+ plotLeft, plotTop = self.plotOrigin
+ plotWidth, plotHeight = self.plotSize
+
+ corners = [(plotLeft, plotTop),
+ (plotLeft, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop)]
+
+ for axis in self.axes:
+ if axis == self.xAxis:
+ cornersInData = numpy.array([
+ self.pixelToData(x, y) for (x, y) in corners])
+ borders = ((cornersInData[0], cornersInData[3]), # top
+ (cornersInData[1], cornersInData[0]), # left
+ (cornersInData[3], cornersInData[2])) # right
+
+ for (xPixel, yPixel), data, text in axis.ticks:
+ if test(text):
+ for (x0, y0), (x1, y1) in borders:
+ if min(x0, x1) <= data < max(x0, x1):
+ yIntersect = (data - x0) * \
+ (y1 - y0) / (x1 - x0) + y0
+
+ pixelPos = self.dataToPixel(
+ data, yIntersect)
+ if pixelPos is not None:
+ vertices.append((xPixel, yPixel))
+ vertices.append(pixelPos)
+ break # Stop at first intersection
+
+ else: # y or y2 axes
+ if axis == self.yAxis:
+ axis_name = 'left'
+ cornersInData = numpy.array([
+ self.pixelToData(x, y) for (x, y) in corners])
+ borders = (
+ (cornersInData[3], cornersInData[2]), # right
+ (cornersInData[0], cornersInData[3]), # top
+ (cornersInData[2], cornersInData[1])) # bottom
+
+ else: # axis == self.y2Axis
+ axis_name = 'right'
+ corners = numpy.array([self.pixelToData(
+ x, y, axis='right') for (x, y) in corners])
+ borders = (
+ (cornersInData[1], cornersInData[0]), # left
+ (cornersInData[0], cornersInData[3]), # top
+ (cornersInData[2], cornersInData[1])) # bottom
+
+ for (xPixel, yPixel), data, text in axis.ticks:
+ if test(text):
+ for (x0, y0), (x1, y1) in borders:
+ if min(y0, y1) <= data < max(y0, y1):
+ xIntersect = (data - y0) * \
+ (x1 - x0) / (y1 - y0) + x0
+
+ pixelPos = self.dataToPixel(
+ xIntersect, data, axis=axis_name)
+ if pixelPos is not None:
+ vertices.append((xPixel, yPixel))
+ vertices.append(pixelPos)
+ break # Stop at first intersection
+
+ return vertices
+
+ def _buildVerticesAndLabels(self):
+ width, height = self.size
+
+ xCoords = (self.margins.left - 0.5,
+ width - self.margins.right + 0.5)
+ yCoords = (height - self.margins.bottom + 0.5,
+ self.margins.top - 0.5)
+
+ self.axes[0].displayCoords = ((xCoords[0], yCoords[0]),
+ (xCoords[1], yCoords[0]))
+
+ self._x2AxisCoords = ((xCoords[0], yCoords[1]),
+ (xCoords[1], yCoords[1]))
+
+ if self.isYAxisInverted:
+ # Y axes are inverted, axes coordinates are inverted
+ yCoords = yCoords[1], yCoords[0]
+
+ self.axes[1].displayCoords = ((xCoords[0], yCoords[0]),
+ (xCoords[0], yCoords[1]))
+
+ self._y2Axis.displayCoords = ((xCoords[1], yCoords[0]),
+ (xCoords[1], yCoords[1]))
+
+ super(GLPlotFrame2D, self)._buildVerticesAndLabels()
+
+ vertices, gridVertices, labels = self._renderResources
+
+ # Adds vertices for borders without axis
+ extraVertices = []
+ extraVertices += self._x2AxisCoords
+ if not self.isY2Axis:
+ extraVertices += self._y2Axis.displayCoords
+
+ extraVertices = numpy.array(
+ extraVertices, copy=False, dtype=numpy.float32)
+ vertices = numpy.append(vertices, extraVertices, axis=0)
+
+ self._renderResources = (vertices, gridVertices, labels)
diff --git a/silx/gui/plot/backends/glutils/GLPlotImage.py b/silx/gui/plot/backends/glutils/GLPlotImage.py
new file mode 100644
index 0000000..8fff82b
--- /dev/null
+++ b/silx/gui/plot/backends/glutils/GLPlotImage.py
@@ -0,0 +1,707 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides a class to render 2D array as a colormap or RGB(A) image
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import math
+import numpy
+
+from silx.math.combo import min_max
+
+from ...._glutils import gl, Program, Texture
+from ..._utils import FLOAT32_MINPOS
+from .GLSupport import mat4Translate, mat4Scale
+from .GLTexture import Image
+
+
+class _GLPlotData2D(object):
+ def __init__(self, data, origin, scale):
+ self.data = data
+ assert len(origin) == 2
+ self.origin = tuple(origin)
+ assert len(scale) == 2
+ self.scale = tuple(scale)
+
+ def pick(self, x, y):
+ if self.xMin <= x <= self.xMax and self.yMin <= y <= self.yMax:
+ ox, oy = self.origin
+ sx, sy = self.scale
+ col = int((x - ox) / sx)
+ row = int((y - oy) / sy)
+ return col, row
+ else:
+ return None
+
+ @property
+ def xMin(self):
+ ox, sx = self.origin[0], self.scale[0]
+ return ox if sx >= 0. else ox + sx * self.data.shape[1]
+
+ @property
+ def yMin(self):
+ oy, sy = self.origin[1], self.scale[1]
+ return oy if sy >= 0. else oy + sy * self.data.shape[0]
+
+ @property
+ def xMax(self):
+ ox, sx = self.origin[0], self.scale[0]
+ return ox + sx * self.data.shape[1] if sx >= 0. else ox
+
+ @property
+ def yMax(self):
+ oy, sy = self.origin[1], self.scale[1]
+ return oy + sy * self.data.shape[0] if sy >= 0. else oy
+
+ def discard(self):
+ pass
+
+ def prepare(self):
+ pass
+
+ def render(self, matrix, isXLog, isYLog):
+ pass
+
+
+class GLPlotColormap(_GLPlotData2D):
+
+ _SHADERS = {
+ 'linear': {
+ 'vertex': """
+ #version 120
+
+ uniform mat4 matrix;
+ attribute vec2 texCoords;
+ attribute vec2 position;
+
+ varying vec2 coords;
+
+ void main(void) {
+ coords = texCoords;
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ }
+ """,
+ 'fragTransform': """
+ vec2 textureCoords(void) {
+ return coords;
+ }
+ """},
+
+ 'log': {
+ 'vertex': """
+ #version 120
+
+ attribute vec2 position;
+ uniform mat4 matrix;
+ uniform mat4 matOffset;
+ uniform bvec2 isLog;
+
+ varying vec2 coords;
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ void main(void) {
+ vec4 dataPos = matOffset * vec4(position, 0.0, 1.0);
+ if (isLog.x) {
+ dataPos.x = oneOverLog10 * log(dataPos.x);
+ }
+ if (isLog.y) {
+ dataPos.y = oneOverLog10 * log(dataPos.y);
+ }
+ coords = dataPos.xy;
+ gl_Position = matrix * dataPos;
+ }
+ """,
+ 'fragTransform': """
+ uniform bvec2 isLog;
+ uniform struct {
+ vec2 oneOverRange;
+ vec2 originOverRange;
+ } bounds;
+
+ vec2 textureCoords(void) {
+ vec2 pos = coords;
+ if (isLog.x) {
+ pos.x = pow(10., coords.x);
+ }
+ if (isLog.y) {
+ pos.y = pow(10., coords.y);
+ }
+ return pos * bounds.oneOverRange - bounds.originOverRange;
+ // TODO texture coords in range different from [0, 1]
+ }
+ """},
+
+ 'fragment': """
+ #version 120
+
+ uniform sampler2D data;
+ uniform struct {
+ sampler2D texture;
+ bool isLog;
+ float min;
+ float oneOverRange;
+ } cmap;
+ uniform float alpha;
+
+ varying vec2 coords;
+
+ %s
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ void main(void) {
+ float value = texture2D(data, textureCoords()).r;
+ if (cmap.isLog) {
+ if (value > 0.) {
+ value = clamp(cmap.oneOverRange *
+ (oneOverLog10 * log(value) - cmap.min),
+ 0., 1.);
+ } else {
+ value = 0.;
+ }
+ } else { /*Linear mapping*/
+ value = clamp(cmap.oneOverRange * (value - cmap.min), 0., 1.);
+ }
+
+ gl_FragColor = texture2D(cmap.texture, vec2(value, 0.5));
+ gl_FragColor.a *= alpha;
+ }
+ """
+ }
+
+ _DATA_TEX_UNIT = 0
+ _CMAP_TEX_UNIT = 1
+
+ _INTERNAL_FORMATS = {
+ numpy.dtype(numpy.float32): gl.GL_R32F,
+ # Use normalized integer for unsigned int formats
+ numpy.dtype(numpy.uint16): gl.GL_R16,
+ numpy.dtype(numpy.uint8): gl.GL_R8,
+ }
+
+ _linearProgram = Program(_SHADERS['linear']['vertex'],
+ _SHADERS['fragment'] %
+ _SHADERS['linear']['fragTransform'],
+ attrib0='position')
+
+ _logProgram = Program(_SHADERS['log']['vertex'],
+ _SHADERS['fragment'] %
+ _SHADERS['log']['fragTransform'],
+ attrib0='position')
+
+ def __init__(self, data, origin, scale,
+ colormap, cmapIsLog=False, cmapRange=None,
+ alpha=1.0):
+ """Create a 2D colormap
+
+ :param data: The 2D scalar data array to display
+ :type data: numpy.ndarray with 2 dimensions (dtype=numpy.float32)
+ :param origin: (x, y) coordinates of the origin of the data array
+ :type origin: 2-tuple of floats.
+ :param scale: (sx, sy) scale factors of the data array.
+ This is the size of a data pixel in plot data space.
+ :type scale: 2-tuple of floats.
+ :param str colormap: Name of the colormap to use
+ TODO: Accept a 1D scalar array as the colormap
+ :param bool cmapIsLog: If True, uses log10 of the data value
+ :param cmapRange: The range of colormap or None for autoscale colormap
+ For logarithmic colormap, the range is in the untransformed data
+ TODO: check consistency with matplotlib
+ :type cmapRange: (float, float) or None
+ :param float alpha: Opacity from 0 (transparent) to 1 (opaque)
+ """
+ assert data.dtype in self._INTERNAL_FORMATS
+
+ super(GLPlotColormap, self).__init__(data, origin, scale)
+ self.colormap = numpy.array(colormap, copy=False)
+ self.cmapIsLog = cmapIsLog
+ self._cmapRange = None # User-provided range info
+ self._cmapRangeCache = None # Store extra data for range
+ self.cmapRange = cmapRange # Update _cmapRange
+ self._alpha = numpy.clip(alpha, 0., 1.)
+
+ self._cmap_texture = None
+ self._texture = None
+ self._textureIsDirty = False
+
+ def discard(self):
+ if self._cmap_texture is not None:
+ self._cmap_texture.discard()
+ self._cmap_texture = None
+
+ if self._texture is not None:
+ self._texture.discard()
+ self._texture = None
+ self._textureIsDirty = False
+
+ @property
+ def cmapRange(self):
+ if self._cmapRange is None: # Auto-scale mode
+ if self._cmapRangeCache is None:
+ # Build data , positive ranges
+ result = min_max(self.data, min_positive=True)
+ min_ = result.minimum
+ minPos = result.min_positive
+ max_ = result.maximum
+ maxPos = max_ if max_ > 0. else 1.
+ if minPos is None:
+ minPos = maxPos
+ self._cmapRangeCache = {'range': (min_, max_),
+ 'pos': (minPos, maxPos)}
+
+ return self._cmapRangeCache['pos' if self.cmapIsLog else 'range']
+
+ else:
+ if not self.cmapIsLog:
+ return self._cmapRange # Return range as is
+ else:
+ if self._cmapRangeCache is None:
+ # Build a strictly positive range from cmapRange
+ min_, max_ = self._cmapRange
+ if min_ > 0. and max_ > 0.:
+ minPos, maxPos = min_, max_
+ else:
+ result = min_max(self.data, min_positive=True)
+ minPos = result.min_positive
+ dataMax = result.maximum
+ if max_ > 0.:
+ maxPos = max_
+ elif dataMax > 0.:
+ maxPos = dataMax
+ else:
+ maxPos = 1. # Arbitrary fallback
+ if minPos is None:
+ minPos = maxPos
+ self._cmapRangeCache = minPos, maxPos
+ return self._cmapRangeCache # Strictly positive range
+
+ @cmapRange.setter
+ def cmapRange(self, cmapRange):
+ self._cmapRangeCache = None
+ if cmapRange is None:
+ self._cmapRange = None
+ else:
+ assert len(cmapRange) == 2
+ assert cmapRange[0] <= cmapRange[1]
+ self._cmapRange = tuple(cmapRange)
+
+ @property
+ def alpha(self):
+ return self._alpha
+
+ def updateData(self, data):
+ assert data.dtype in self._INTERNAL_FORMATS
+ oldData = self.data
+ self.data = data
+
+ self._cmapRangeCache = None
+
+ if self._texture is not None:
+ if (self.data.shape != oldData.shape or
+ self.data.dtype != oldData.dtype):
+ self.discard()
+ else:
+ self._textureIsDirty = True
+
+ def prepare(self):
+ if self._cmap_texture is None:
+ # TODO share cmap texture accross Images
+ # put all cmaps in one texture
+ colormap = numpy.empty((16, 256, self.colormap.shape[1]),
+ dtype=self.colormap.dtype)
+ colormap[:] = self.colormap
+ format_ = gl.GL_RGBA if colormap.shape[-1] == 4 else gl.GL_RGB
+ self._cmap_texture = Texture(internalFormat=format_,
+ data=colormap,
+ format_=format_,
+ texUnit=self._CMAP_TEX_UNIT,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=(gl.GL_CLAMP_TO_EDGE,
+ gl.GL_CLAMP_TO_EDGE))
+
+ if self._texture is None:
+ internalFormat = self._INTERNAL_FORMATS[self.data.dtype]
+
+ self._texture = Image(internalFormat,
+ self.data,
+ format_=gl.GL_RED,
+ texUnit=self._DATA_TEX_UNIT)
+ elif self._textureIsDirty:
+ self._textureIsDirty = True
+ self._texture.updateAll(format_=gl.GL_RED, data=self.data)
+
+ def _setCMap(self, prog):
+ dataMin, dataMax = self.cmapRange # If log, it is stricly positive
+
+ if self.data.dtype in (numpy.uint16, numpy.uint8):
+ # Using unsigned int as normalized integer in OpenGL
+ # So normalize range
+ maxInt = float(numpy.iinfo(self.data.dtype).max)
+ dataMin, dataMax = dataMin / maxInt, dataMax / maxInt
+
+ if self.cmapIsLog:
+ dataMin = math.log10(dataMin)
+ dataMax = math.log10(dataMax)
+
+ gl.glUniform1i(prog.uniforms['cmap.texture'],
+ self._cmap_texture.texUnit)
+ gl.glUniform1i(prog.uniforms['cmap.isLog'], self.cmapIsLog)
+ gl.glUniform1f(prog.uniforms['cmap.min'], dataMin)
+ if dataMax > dataMin:
+ oneOverRange = 1. / (dataMax - dataMin)
+ else:
+ oneOverRange = 0. # Fall-back
+ gl.glUniform1f(prog.uniforms['cmap.oneOverRange'], oneOverRange)
+
+ self._cmap_texture.bind()
+
+ def _renderLinear(self, matrix):
+ self.prepare()
+
+ prog = self._linearProgram
+ prog.use()
+
+ gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT)
+
+ mat = matrix * mat4Translate(*self.origin) * mat4Scale(*self.scale)
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, mat)
+
+ gl.glUniform1f(prog.uniforms['alpha'], self.alpha)
+
+ self._setCMap(prog)
+
+ self._texture.render(prog.attributes['position'],
+ prog.attributes['texCoords'],
+ self._DATA_TEX_UNIT)
+
+ def _renderLog10(self, matrix, isXLog, isYLog):
+ xMin, yMin = self.xMin, self.yMin
+ if ((isXLog and xMin < FLOAT32_MINPOS) or
+ (isYLog and yMin < FLOAT32_MINPOS)):
+ # Do not render images that are partly or totally <= 0
+ return
+
+ self.prepare()
+
+ prog = self._logProgram
+ prog.use()
+
+ ox, oy = self.origin
+
+ gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT)
+
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix)
+ mat = mat4Translate(ox, oy) * mat4Scale(*self.scale)
+ gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE, mat)
+
+ gl.glUniform2i(prog.uniforms['isLog'], isXLog, isYLog)
+
+ ex = ox + self.scale[0] * self.data.shape[1]
+ ey = oy + self.scale[1] * self.data.shape[0]
+
+ xOneOverRange = 1. / (ex - ox)
+ yOneOverRange = 1. / (ey - oy)
+ gl.glUniform2f(prog.uniforms['bounds.originOverRange'],
+ ox * xOneOverRange, oy * yOneOverRange)
+ gl.glUniform2f(prog.uniforms['bounds.oneOverRange'],
+ xOneOverRange, yOneOverRange)
+
+ gl.glUniform1f(prog.uniforms['alpha'], self.alpha)
+
+ self._setCMap(prog)
+
+ try:
+ tiles = self._texture.tiles
+ except AttributeError:
+ raise RuntimeError("No texture, discard has already been called")
+ if len(tiles) > 1:
+ raise NotImplementedError(
+ "Image over multiple textures not supported with log scale")
+
+ texture, vertices, info = tiles[0]
+
+ texture.bind(self._DATA_TEX_UNIT)
+
+ posAttrib = prog.attributes['position']
+ stride = vertices.shape[-1] * vertices.itemsize
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ stride, vertices)
+
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
+
+ def render(self, matrix, isXLog, isYLog):
+ if any((isXLog, isYLog)):
+ self._renderLog10(matrix, isXLog, isYLog)
+ else:
+ self._renderLinear(matrix)
+
+ # Unbind colormap texture
+ gl.glActiveTexture(gl.GL_TEXTURE0 + self._cmap_texture.texUnit)
+ gl.glBindTexture(self._cmap_texture.target, 0)
+
+
+# image #######################################################################
+
+class GLPlotRGBAImage(_GLPlotData2D):
+
+ _SHADERS = {
+ 'linear': {
+ 'vertex': """
+ #version 120
+
+ attribute vec2 position;
+ attribute vec2 texCoords;
+ uniform mat4 matrix;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ coords = texCoords;
+ }
+ """,
+ 'fragment': """
+ #version 120
+
+ uniform sampler2D tex;
+ uniform float alpha;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_FragColor = texture2D(tex, coords);
+ gl_FragColor.a *= alpha;
+ }
+ """},
+
+ 'log': {
+ 'vertex': """
+ #version 120
+
+ attribute vec2 position;
+ uniform mat4 matrix;
+ uniform mat4 matOffset;
+ uniform bvec2 isLog;
+
+ varying vec2 coords;
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ void main(void) {
+ vec4 dataPos = matOffset * vec4(position, 0.0, 1.0);
+ if (isLog.x) {
+ dataPos.x = oneOverLog10 * log(dataPos.x);
+ }
+ if (isLog.y) {
+ dataPos.y = oneOverLog10 * log(dataPos.y);
+ }
+ coords = dataPos.xy;
+ gl_Position = matrix * dataPos;
+ }
+ """,
+ 'fragment': """
+ #version 120
+
+ uniform sampler2D tex;
+ uniform bvec2 isLog;
+ uniform struct {
+ vec2 oneOverRange;
+ vec2 originOverRange;
+ } bounds;
+ uniform float alpha;
+
+ varying vec2 coords;
+
+ vec2 textureCoords(void) {
+ vec2 pos = coords;
+ if (isLog.x) {
+ pos.x = pow(10., coords.x);
+ }
+ if (isLog.y) {
+ pos.y = pow(10., coords.y);
+ }
+ return pos * bounds.oneOverRange - bounds.originOverRange;
+ // TODO texture coords in range different from [0, 1]
+ }
+
+ void main(void) {
+ gl_FragColor = texture2D(tex, textureCoords());
+ gl_FragColor.a *= alpha;
+ }
+ """}
+ }
+
+ _DATA_TEX_UNIT = 0
+
+ _SUPPORTED_DTYPES = (numpy.dtype(numpy.float32),
+ numpy.dtype(numpy.uint8))
+
+ _linearProgram = Program(_SHADERS['linear']['vertex'],
+ _SHADERS['linear']['fragment'],
+ attrib0='position')
+
+ _logProgram = Program(_SHADERS['log']['vertex'],
+ _SHADERS['log']['fragment'],
+ attrib0='position')
+
+ def __init__(self, data, origin, scale, alpha):
+ """Create a 2D RGB(A) image from data
+
+ :param data: The 2D image data array to display
+ :type data: numpy.ndarray with 3 dimensions
+ (dtype=numpy.uint8 or numpy.float32)
+ :param origin: (x, y) coordinates of the origin of the data array
+ :type origin: 2-tuple of floats.
+ :param scale: (sx, sy) scale factors of the data array.
+ This is the size of a data pixel in plot data space.
+ :type scale: 2-tuple of floats.
+ :param float alpha: Opacity from 0 (transparent) to 1 (opaque)
+ """
+ assert data.dtype in self._SUPPORTED_DTYPES
+ super(GLPlotRGBAImage, self).__init__(data, origin, scale)
+ self._texture = None
+ self._textureIsDirty = False
+ self._alpha = numpy.clip(alpha, 0., 1.)
+
+ @property
+ def alpha(self):
+ return self._alpha
+
+ def discard(self):
+ if self._texture is not None:
+ self._texture.discard()
+ self._texture = None
+ self._textureIsDirty = False
+
+ def updateData(self, data):
+ assert data.dtype in self._SUPPORTED_DTYPES
+ oldData = self.data
+ self.data = data
+
+ if self._texture is not None:
+ if self.data.shape != oldData.shape:
+ self.discard()
+ else:
+ self._textureIsDirty = True
+
+ def prepare(self):
+ if self._texture is None:
+ format_ = gl.GL_RGBA if self.data.shape[2] == 4 else gl.GL_RGB
+
+ self._texture = Image(format_,
+ self.data,
+ format_=format_,
+ texUnit=self._DATA_TEX_UNIT)
+ elif self._textureIsDirty:
+ self._textureIsDirty = False
+
+ # We should check that internal format is the same
+ format_ = gl.GL_RGBA if self.data.shape[2] == 4 else gl.GL_RGB
+ self._texture.updateAll(format_=format_, data=self.data)
+
+ def _renderLinear(self, matrix):
+ self.prepare()
+
+ prog = self._linearProgram
+ prog.use()
+
+ gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT)
+
+ mat = matrix * mat4Translate(*self.origin) * mat4Scale(*self.scale)
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, mat)
+
+ gl.glUniform1f(prog.uniforms['alpha'], self.alpha)
+
+ self._texture.render(prog.attributes['position'],
+ prog.attributes['texCoords'],
+ self._DATA_TEX_UNIT)
+
+ def _renderLog(self, matrix, isXLog, isYLog):
+ self.prepare()
+
+ prog = self._logProgram
+ prog.use()
+
+ ox, oy = self.origin
+
+ gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT)
+
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, matrix)
+ mat = mat4Translate(ox, oy) * mat4Scale(*self.scale)
+ gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE, mat)
+
+ gl.glUniform2i(prog.uniforms['isLog'], isXLog, isYLog)
+
+ gl.glUniform1f(prog.uniforms['alpha'], self.alpha)
+
+ ex = ox + self.scale[0] * self.data.shape[1]
+ ey = oy + self.scale[1] * self.data.shape[0]
+
+ xOneOverRange = 1. / (ex - ox)
+ yOneOverRange = 1. / (ey - oy)
+ gl.glUniform2f(prog.uniforms['bounds.originOverRange'],
+ ox * xOneOverRange, oy * yOneOverRange)
+ gl.glUniform2f(prog.uniforms['bounds.oneOverRange'],
+ xOneOverRange, yOneOverRange)
+
+ try:
+ tiles = self._texture.tiles
+ except AttributeError:
+ raise RuntimeError("No texture, discard has already been called")
+ if len(tiles) > 1:
+ raise NotImplementedError(
+ "Image over multiple textures not supported with log scale")
+
+ texture, vertices, info = tiles[0]
+
+ texture.bind(self._DATA_TEX_UNIT)
+
+ posAttrib = prog.attributes['position']
+ stride = vertices.shape[-1] * vertices.itemsize
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ stride, vertices)
+
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
+
+ def render(self, matrix, isXLog, isYLog):
+ if any((isXLog, isYLog)):
+ self._renderLog(matrix, isXLog, isYLog)
+ else:
+ self._renderLinear(matrix)
diff --git a/silx/gui/plot/backends/glutils/GLSupport.py b/silx/gui/plot/backends/glutils/GLSupport.py
new file mode 100644
index 0000000..3f473be
--- /dev/null
+++ b/silx/gui/plot/backends/glutils/GLSupport.py
@@ -0,0 +1,192 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides convenient classes and functions for OpenGL rendering.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import numpy
+
+from ...._glutils import gl
+
+
+def buildFillMaskIndices(nIndices):
+ if nIndices <= numpy.iinfo(numpy.uint16).max + 1:
+ dtype = numpy.uint16
+ else:
+ dtype = numpy.uint32
+
+ lastIndex = nIndices - 1
+ splitIndex = lastIndex // 2 + 1
+ indices = numpy.empty(nIndices, dtype=dtype)
+ indices[::2] = numpy.arange(0, splitIndex, step=1, dtype=dtype)
+ indices[1::2] = numpy.arange(lastIndex, splitIndex - 1, step=-1,
+ dtype=dtype)
+ return indices
+
+
+class Shape2D(object):
+ _NO_HATCH = 0
+ _HATCH_STEP = 20
+
+ def __init__(self, points, fill='solid', stroke=True,
+ fillColor=(0., 0., 0., 1.), strokeColor=(0., 0., 0., 1.),
+ strokeClosed=True):
+ self.vertices = numpy.array(points, dtype=numpy.float32, copy=False)
+ self.strokeClosed = strokeClosed
+
+ self._indices = buildFillMaskIndices(len(self.vertices))
+
+ tVertex = numpy.transpose(self.vertices)
+ xMin, xMax = min(tVertex[0]), max(tVertex[0])
+ yMin, yMax = min(tVertex[1]), max(tVertex[1])
+ self.bboxVertices = numpy.array(((xMin, yMin), (xMin, yMax),
+ (xMax, yMin), (xMax, yMax)),
+ dtype=numpy.float32)
+ self._xMin, self._xMax = xMin, xMax
+ self._yMin, self._yMax = yMin, yMax
+
+ self.fill = fill
+ self.fillColor = fillColor
+ self.stroke = stroke
+ self.strokeColor = strokeColor
+
+ @property
+ def xMin(self):
+ return self._xMin
+
+ @property
+ def xMax(self):
+ return self._xMax
+
+ @property
+ def yMin(self):
+ return self._yMin
+
+ @property
+ def yMax(self):
+ return self._yMax
+
+ def prepareFillMask(self, posAttrib):
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, self.vertices)
+
+ gl.glEnable(gl.GL_STENCIL_TEST)
+ gl.glStencilMask(1)
+ gl.glStencilFunc(gl.GL_ALWAYS, 1, 1)
+ gl.glStencilOp(gl.GL_INVERT, gl.GL_INVERT, gl.GL_INVERT)
+ gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
+ gl.glDepthMask(gl.GL_FALSE)
+
+ gl.glDrawElements(gl.GL_TRIANGLE_STRIP, len(self._indices),
+ gl.GL_UNSIGNED_SHORT, self._indices)
+
+ gl.glStencilFunc(gl.GL_EQUAL, 1, 1)
+ # Reset stencil while drawing
+ gl.glStencilOp(gl.GL_ZERO, gl.GL_ZERO, gl.GL_ZERO)
+ gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
+ gl.glDepthMask(gl.GL_TRUE)
+
+ def renderFill(self, posAttrib):
+ self.prepareFillMask(posAttrib)
+
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, self.bboxVertices)
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self.bboxVertices))
+
+ gl.glDisable(gl.GL_STENCIL_TEST)
+
+ def renderStroke(self, posAttrib):
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, self.vertices)
+ gl.glLineWidth(1)
+ drawMode = gl.GL_LINE_LOOP if self.strokeClosed else gl.GL_LINE_STRIP
+ gl.glDrawArrays(drawMode, 0, len(self.vertices))
+
+ def render(self, posAttrib, colorUnif, hatchStepUnif):
+ assert self.fill in ['hatch', 'solid', None]
+ if self.fill is not None:
+ gl.glUniform4f(colorUnif, *self.fillColor)
+ step = self._HATCH_STEP if self.fill == 'hatch' else self._NO_HATCH
+ gl.glUniform1i(hatchStepUnif, step)
+ self.renderFill(posAttrib)
+
+ if self.stroke:
+ gl.glUniform4f(colorUnif, *self.strokeColor)
+ gl.glUniform1i(hatchStepUnif, self._NO_HATCH)
+ self.renderStroke(posAttrib)
+
+
+# matrix ######################################################################
+
+def mat4Ortho(left, right, bottom, top, near, far):
+ """Orthographic projection matrix (row-major)"""
+ return numpy.matrix((
+ (2./(right - left), 0., 0., -(right+left)/float(right-left)),
+ (0., 2./(top - bottom), 0., -(top+bottom)/float(top-bottom)),
+ (0., 0., -2./(far-near), -(far+near)/float(far-near)),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+
+def mat4Translate(x=0., y=0., z=0.):
+ """Translation matrix (row-major)"""
+ return numpy.matrix((
+ (1., 0., 0., x),
+ (0., 1., 0., y),
+ (0., 0., 1., z),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+
+def mat4Scale(sx=1., sy=1., sz=1.):
+ """Scale matrix (row-major)"""
+ return numpy.matrix((
+ (sx, 0., 0., 0.),
+ (0., sy, 0., 0.),
+ (0., 0., sz, 0.),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+
+def mat4Identity():
+ """Identity matrix"""
+ return numpy.matrix((
+ (1., 0., 0., 0.),
+ (0., 1., 0., 0.),
+ (0., 0., 1., 0.),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
diff --git a/silx/gui/plot/backends/glutils/GLText.py b/silx/gui/plot/backends/glutils/GLText.py
new file mode 100644
index 0000000..495882c
--- /dev/null
+++ b/silx/gui/plot/backends/glutils/GLText.py
@@ -0,0 +1,222 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides minimalistic text support for OpenGL.
+It provides Latin-1 (ISO8859-1) characters for one monospace font at one size.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import numpy
+
+from ...._glutils import font, gl, getGLContext, Program, Texture
+from .GLSupport import mat4Translate
+
+
+# TODO: Font should be configurable by the main program: using mpl.rcParams?
+
+
+# Text2D ######################################################################
+
+LEFT, CENTER, RIGHT = 'left', 'center', 'right'
+TOP, BASELINE, BOTTOM = 'top', 'baseline', 'bottom'
+ROTATE_90, ROTATE_180, ROTATE_270 = 90, 180, 270
+
+
+class Text2D(object):
+
+ _SHADERS = {
+ 'vertex': """
+ #version 120
+
+ attribute vec2 position;
+ attribute vec2 texCoords;
+ uniform mat4 matrix;
+
+ varying vec2 vCoords;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ vCoords = texCoords;
+ }
+ """,
+ 'fragment': """
+ #version 120
+
+ uniform sampler2D texText;
+ uniform vec4 color;
+ uniform vec4 bgColor;
+
+ varying vec2 vCoords;
+
+ void main(void) {
+ gl_FragColor = mix(bgColor, color, texture2D(texText, vCoords).r);
+ }
+ """
+ }
+
+ _TEX_COORDS = numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)),
+ dtype=numpy.float32).ravel()
+
+ _program = Program(_SHADERS['vertex'],
+ _SHADERS['fragment'],
+ attrib0='position')
+
+ _textures = {}
+
+ _rasterTextCache = {}
+ """Internal cache storing already rasterized text"""
+ # TODO limit cache size and discard least recent used
+
+ def __init__(self, text, x=0, y=0,
+ color=(0., 0., 0., 1.),
+ bgColor=None,
+ align=LEFT, valign=BASELINE,
+ rotate=0):
+ self._vertices = None
+ self._text = text
+ self.x = x
+ self.y = y
+ self.color = color
+ self.bgColor = bgColor
+
+ if align not in (LEFT, CENTER, RIGHT):
+ raise ValueError(
+ "Horizontal alignment not supported: {0}".format(align))
+ self._align = align
+
+ if valign not in (TOP, CENTER, BASELINE, BOTTOM):
+ raise ValueError(
+ "Vertical alignment not supported: {0}".format(valign))
+ self._valign = valign
+
+ self._rotate = numpy.radians(rotate)
+
+ @classmethod
+ def _getTexture(cls, text):
+ key = getGLContext(), text
+ if key not in cls._textures:
+ image, offset = font.rasterText(text,
+ font.getDefaultFontFamily())
+ cls._textures[key] = (Texture(gl.GL_RED,
+ data=image,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=(gl.GL_CLAMP_TO_EDGE,
+ gl.GL_CLAMP_TO_EDGE)),
+ offset)
+
+ return cls._textures[key]
+
+ @property
+ def text(self):
+ return self._text
+
+ @property
+ def size(self): # TODO very poor implementation
+ image, offset = font.rasterText(self.text,
+ font.getDefaultFontFamily())
+ return image.shape[1], image.shape[0]
+
+ def getVertices(self, offset, shape):
+ height, width = shape
+
+ if self._align == LEFT:
+ xOrig = 0
+ elif self._align == RIGHT:
+ xOrig = - width
+ else: # CENTER
+ xOrig = - width // 2
+
+ if self._valign == BASELINE:
+ yOrig = - offset
+ elif self._valign == TOP:
+ yOrig = 0
+ elif self._valign == BOTTOM:
+ yOrig = - height
+ else: # CENTER
+ yOrig = - height // 2
+
+ vertices = numpy.array((
+ (xOrig, yOrig),
+ (xOrig + width, yOrig),
+ (xOrig, yOrig + height),
+ (xOrig + width, yOrig + height)), dtype=numpy.float32)
+
+ cos, sin = numpy.cos(self._rotate), numpy.sin(self._rotate)
+ vertices = numpy.ascontiguousarray(numpy.transpose(numpy.array((
+ cos * vertices[:, 0] - sin * vertices[:, 1],
+ sin * vertices[:, 0] + cos * vertices[:, 1]),
+ dtype=numpy.float32)))
+
+ return vertices
+
+ def render(self, matrix):
+ if not self.text:
+ return
+
+ prog = self._program
+ prog.use()
+
+ texUnit = 0
+ texture, offset = self._getTexture(self.text)
+
+ gl.glUniform1i(prog.uniforms['texText'], texUnit)
+
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
+ matrix * mat4Translate(int(self.x), int(self.y)))
+
+ gl.glUniform4f(prog.uniforms['color'], *self.color)
+ if self.bgColor is not None:
+ bgColor = self.bgColor
+ else:
+ bgColor = self.color[0], self.color[1], self.color[2], 0.
+ gl.glUniform4f(prog.uniforms['bgColor'], *bgColor)
+
+ vertices = self.getVertices(offset, texture.shape)
+
+ posAttrib = prog.attributes['position']
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ vertices)
+
+ texAttrib = prog.attributes['texCoords']
+ gl.glEnableVertexAttribArray(texAttrib)
+ gl.glVertexAttribPointer(texAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._TEX_COORDS)
+
+ with texture:
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, 4)
diff --git a/silx/gui/plot/backends/glutils/GLTexture.py b/silx/gui/plot/backends/glutils/GLTexture.py
new file mode 100644
index 0000000..25dd9f1
--- /dev/null
+++ b/silx/gui/plot/backends/glutils/GLTexture.py
@@ -0,0 +1,239 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""This module provides classes wrapping OpenGL texture."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+from ctypes import c_void_p
+import logging
+
+import numpy
+
+from ...._glutils import gl, Texture, numpyToGLType
+
+
+_logger = logging.getLogger(__name__)
+
+
+def _checkTexture2D(internalFormat, shape,
+ format_=None, type_=gl.GL_FLOAT, border=0):
+ """Check if texture size with provided parameters is supported
+
+ :rtype: bool
+ """
+ height, width = shape
+ gl.glTexImage2D(gl.GL_PROXY_TEXTURE_2D, 0, internalFormat,
+ width, height, border,
+ format_ or internalFormat,
+ type_, c_void_p(0))
+ width = gl.glGetTexLevelParameteriv(
+ gl.GL_PROXY_TEXTURE_2D, 0, gl.GL_TEXTURE_WIDTH)
+ return bool(width)
+
+
+MIN_TEXTURE_SIZE = 64
+
+
+def _getMaxSquareTexture2DSize(internalFormat=gl.GL_RGBA,
+ format_=None,
+ type_=gl.GL_FLOAT,
+ border=0):
+ """Returns a supported size for a corresponding square texture
+
+ :returns: GL_MAX_TEXTURE_SIZE or a smaller supported size (not optimal)
+ :rtype: int
+ """
+ # Is this useful?
+ maxTexSize = gl.glGetIntegerv(gl.GL_MAX_TEXTURE_SIZE)
+ while maxTexSize > MIN_TEXTURE_SIZE and \
+ not _checkTexture2D(internalFormat, (maxTexSize, maxTexSize),
+ format_, type_, border):
+ maxTexSize //= 2
+ return max(MIN_TEXTURE_SIZE, maxTexSize)
+
+
+class Image(object):
+ """Image of any size eventually using multiple textures or larger texture
+ """
+
+ _WRAP = (gl.GL_CLAMP_TO_EDGE, gl.GL_CLAMP_TO_EDGE)
+ _MIN_FILTER = gl.GL_NEAREST
+ _MAG_FILTER = gl.GL_NEAREST
+
+ def __init__(self, internalFormat, data, format_=None, texUnit=0):
+ self.internalFormat = internalFormat
+ self.height, self.width = data.shape[0:2]
+ type_ = numpyToGLType(data.dtype)
+
+ if _checkTexture2D(internalFormat, data.shape[0:2], format_, type_):
+ texture = Texture(internalFormat,
+ data,
+ format_,
+ texUnit=texUnit,
+ minFilter=self._MIN_FILTER,
+ magFilter=self._MAG_FILTER,
+ wrap=self._WRAP)
+ vertices = numpy.array((
+ (0., 0., 0., 0.),
+ (self.width, 0., 1., 0.),
+ (0., self.height, 0., 1.),
+ (self.width, self.height, 1., 1.)), dtype=numpy.float32)
+ self.tiles = ((texture, vertices,
+ {'xOrigData': 0, 'yOrigData': 0,
+ 'wData': self.width, 'hData': self.height}),)
+
+ else:
+ # Handle dimension too large: make tiles
+ maxTexSize = _getMaxSquareTexture2DSize(internalFormat,
+ format_, type_)
+
+ nCols = (self.width+maxTexSize-1) // maxTexSize
+ colWidths = [self.width // nCols] * nCols
+ colWidths[-1] += self.width % nCols
+
+ nRows = (self.height+maxTexSize-1) // maxTexSize
+ rowHeights = [self.height//nRows] * nRows
+ rowHeights[-1] += self.height % nRows
+
+ tiles = []
+ yOrig = 0
+ for hData in rowHeights:
+ xOrig = 0
+ for wData in colWidths:
+ if (hData < MIN_TEXTURE_SIZE or wData < MIN_TEXTURE_SIZE) \
+ and not _checkTexture2D(internalFormat,
+ (hData, wData),
+ format_,
+ type_):
+ # Ensure texture size is at least MIN_TEXTURE_SIZE
+ tH = max(hData, MIN_TEXTURE_SIZE)
+ tW = max(wData, MIN_TEXTURE_SIZE)
+
+ uMax, vMax = float(wData)/tW, float(hData)/tH
+
+ # TODO issue with type_ and alignment
+ texture = Texture(internalFormat,
+ data=None,
+ format_=format_,
+ shape=(tH, tW),
+ texUnit=texUnit,
+ minFilter=self._MIN_FILTER,
+ magFilter=self._MAG_FILTER,
+ wrap=self._WRAP)
+ # TODO handle unpack
+ texture.update(format_,
+ data[yOrig:yOrig+hData,
+ xOrig:xOrig+wData])
+ # texture.update(format_, type_, data,
+ # width=wData, height=hData,
+ # unpackRowLength=width,
+ # unpackSkipPixels=xOrig,
+ # unpackSkipRows=yOrig)
+ else:
+ uMax, vMax = 1, 1
+ # TODO issue with type_ and unpacking tiles
+ # TODO idea to handle unpack: use array strides
+ # As it is now, it will make a copy
+ texture = Texture(internalFormat,
+ data[yOrig:yOrig+hData,
+ xOrig:xOrig+wData],
+ format_,
+ shape=(hData, wData),
+ texUnit=texUnit,
+ minFilter=self._MIN_FILTER,
+ magFilter=self._MAG_FILTER,
+ wrap=self._WRAP)
+ # TODO
+ # unpackRowLength=width,
+ # unpackSkipPixels=xOrig,
+ # unpackSkipRows=yOrig)
+ vertices = numpy.array((
+ (xOrig, yOrig, 0., 0.),
+ (xOrig + wData, yOrig, uMax, 0.),
+ (xOrig, yOrig + hData, 0., vMax),
+ (xOrig + wData, yOrig + hData, uMax, vMax)),
+ dtype=numpy.float32)
+ tiles.append((texture, vertices,
+ {'xOrigData': xOrig, 'yOrigData': yOrig,
+ 'wData': wData, 'hData': hData}))
+ xOrig += wData
+ yOrig += hData
+ self.tiles = tuple(tiles)
+
+ def discard(self):
+ for texture, vertices, _ in self.tiles:
+ texture.discard()
+ del self.tiles
+
+ def updateAll(self, format_, data, texUnit=0):
+ if not hasattr(self, 'tiles'):
+ raise RuntimeError("No texture, discard has already been called")
+
+ assert data.shape[:2] == (self.height, self.width)
+ if len(self.tiles) == 1:
+ self.tiles[0][0].update(format_, data, texUnit=texUnit)
+ else:
+ for texture, _, info in self.tiles:
+ yOrig, xOrig = info['yOrigData'], info['xOrigData']
+ height, width = info['hData'], info['wData']
+ texture.update(format_,
+ data[yOrig:yOrig+height, xOrig:xOrig+width],
+ texUnit=texUnit)
+ # TODO check
+ # width=info['wData'], height=info['hData'],
+ # texUnit=texUnit, unpackAlign=unpackAlign,
+ # unpackRowLength=self.width,
+ # unpackSkipPixels=info['xOrigData'],
+ # unpackSkipRows=info['yOrigData'])
+
+ def render(self, posAttrib, texAttrib, texUnit=0):
+ try:
+ tiles = self.tiles
+ except AttributeError:
+ raise RuntimeError("No texture, discard has already been called")
+
+ for texture, vertices, _ in tiles:
+ texture.bind(texUnit)
+
+ stride = vertices.shape[-1] * vertices.itemsize
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ stride, vertices)
+
+ texCoordsPtr = c_void_p(vertices.ctypes.data +
+ 2 * vertices.itemsize)
+ gl.glEnableVertexAttribArray(texAttrib)
+ gl.glVertexAttribPointer(texAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ stride, texCoordsPtr)
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
diff --git a/silx/gui/plot/backends/glutils/PlotImageFile.py b/silx/gui/plot/backends/glutils/PlotImageFile.py
new file mode 100644
index 0000000..e4ebe24
--- /dev/null
+++ b/silx/gui/plot/backends/glutils/PlotImageFile.py
@@ -0,0 +1,149 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Function to save an image to a file."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import base64
+import struct
+import sys
+import zlib
+
+
+# Image writer ################################################################
+
+def convertRGBDataToPNG(data):
+ """Convert a RGB bitmap to PNG.
+
+ It only supports RGB bitmap with one byte per channel stored as a 3D array.
+ See `Definitive Guide <http://www.libpng.org/pub/png/book/>`_ and
+ `Specification <http://www.libpng.org/pub/png/spec/1.2/>`_ for details.
+
+ :param data: A 3D array (h, w, rgb) storing an RGB image
+ :type data: numpy.ndarray of unsigned bytes
+ :returns: The PNG encoded data
+ :rtype: bytes
+ """
+ height, width = data.shape[0], data.shape[1]
+ depth = 8 # 8 bit per channel
+ colorType = 2 # 'truecolor' = RGB
+ interlace = 0 # No
+
+ IHDRdata = struct.pack(">ccccIIBBBBB", b'I', b'H', b'D', b'R',
+ width, height, depth, colorType,
+ 0, 0, interlace)
+
+ # Add filter 'None' before each scanline
+ preparedData = b'\x00' + b'\x00'.join(line.tostring() for line in data)
+ compressedData = zlib.compress(preparedData, 8)
+
+ IDATdata = struct.pack("cccc", b'I', b'D', b'A', b'T')
+ IDATdata += compressedData
+
+ return b''.join([
+ b'\x89PNG\r\n\x1a\n', # PNG signature
+ # IHDR chunk: Image Header
+ struct.pack(">I", 13), # length
+ IHDRdata,
+ struct.pack(">I", zlib.crc32(IHDRdata) & 0xffffffff), # CRC
+ # IDAT chunk: Payload
+ struct.pack(">I", len(compressedData)),
+ IDATdata,
+ struct.pack(">I", zlib.crc32(IDATdata) & 0xffffffff), # CRC
+ b'\x00\x00\x00\x00IEND\xaeB`\x82' # IEND chunk: footer
+ ])
+
+
+def saveImageToFile(data, fileNameOrObj, fileFormat):
+ """Save a RGB image to a file.
+
+ :param data: A 3D array (h, w, 3) storing an RGB image.
+ :type data: numpy.ndarray with of unsigned bytes.
+ :param fileNameOrObj: Filename or object to use to write the image.
+ :type fileNameOrObj: A str or a 'file-like' object with a 'write' method.
+ :param str fileFormat: The type of the file in: 'png', 'ppm', 'svg', 'tiff'.
+ """
+ assert len(data.shape) == 3
+ assert data.shape[2] == 3
+ assert fileFormat in ('png', 'ppm', 'svg', 'tiff')
+
+ if not hasattr(fileNameOrObj, 'write'):
+ if sys.version < "3.0":
+ fileObj = open(fileNameOrObj, "wb")
+ else:
+ fileObj = open(fileNameOrObj, "w", newline='')
+ else: # Use as a file-like object
+ fileObj = fileNameOrObj
+
+ if fileFormat == 'svg':
+ height, width = data.shape[:2]
+ base64Data = base64.b64encode(convertRGBDataToPNG(data))
+
+ fileObj.write(
+ '<?xml version="1.0" encoding="UTF-8" standalone="no"?>\n')
+ fileObj.write('<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"\n')
+ fileObj.write(
+ ' "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n')
+ fileObj.write('<svg xmlns:xlink="http://www.w3.org/1999/xlink"\n')
+ fileObj.write(' xmlns="http://www.w3.org/2000/svg"\n')
+ fileObj.write(' version="1.1"\n')
+ fileObj.write(' width="%d"\n' % width)
+ fileObj.write(' height="%d">\n' % height)
+ fileObj.write(' <image xlink:href="data:image/png;base64,')
+ fileObj.write(base64Data.decode('ascii'))
+ fileObj.write('"\n')
+ fileObj.write(' x="0"\n')
+ fileObj.write(' y="0"\n')
+ fileObj.write(' width="%d"\n' % width)
+ fileObj.write(' height="%d"\n' % height)
+ fileObj.write(' id="image" />\n')
+ fileObj.write('</svg>')
+
+ elif fileFormat == 'ppm':
+ height, width = data.shape[:2]
+
+ fileObj.write('P6\n')
+ fileObj.write('%d %d\n' % (width, height))
+ fileObj.write('255\n')
+ fileObj.write(data.tostring())
+
+ elif fileFormat == 'png':
+ fileObj.write(convertRGBDataToPNG(data))
+
+ elif fileFormat == 'tiff':
+ if fileObj == fileNameOrObj:
+ raise NotImplementedError(
+ 'Save TIFF to a file-like object not implemented')
+
+ from silx.third_party.TiffIO import TiffIO
+
+ tif = TiffIO(fileNameOrObj, mode='wb+')
+ tif.writeImage(data, info={'Title': 'PyMCA GL Snapshot'})
+
+ if fileObj != fileNameOrObj:
+ fileObj.close()
diff --git a/silx/gui/plot/backends/glutils/__init__.py b/silx/gui/plot/backends/glutils/__init__.py
new file mode 100644
index 0000000..771de39
--- /dev/null
+++ b/silx/gui/plot/backends/glutils/__init__.py
@@ -0,0 +1,44 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""This module provides convenient classes for the OpenGL rendering backend.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import logging
+
+
+_logger = logging.getLogger(__name__)
+
+
+from .GLPlotCurve import * # noqa
+from .GLPlotFrame import * # noqa
+from .GLPlotImage import * # noqa
+from .GLSupport import * # noqa
+from .GLText import * # noqa
+from .GLTexture import * # noqa
diff --git a/silx/gui/plot/items/__init__.py b/silx/gui/plot/items/__init__.py
new file mode 100644
index 0000000..b16fe40
--- /dev/null
+++ b/silx/gui/plot/items/__init__.py
@@ -0,0 +1,43 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides classes that describes :class:`.Plot` content.
+
+Instances of those classes are returned by :class:`.Plot` methods that give
+access to its content such as :meth:`.Plot.getCurve`, :meth:`.Plot.getImage`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2017"
+
+from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa
+ SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa
+ AlphaMixIn, LineMixIn) # noqa
+from .curve import Curve # noqa
+from .histogram import Histogram # noqa
+from .image import ImageBase, ImageData, ImageRgba # noqa
+from .shape import Shape # noqa
+from .scatter import Scatter # noqa
+from .marker import Marker, XMarker, YMarker # noqa
diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py
new file mode 100644
index 0000000..72bfd9a
--- /dev/null
+++ b/silx/gui/plot/items/core.py
@@ -0,0 +1,839 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the base class for items of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/04/2017"
+
+from copy import deepcopy
+import logging
+import weakref
+import numpy
+from silx.third_party import six
+
+from .. import Colors
+
+
+
+_logger = logging.getLogger(__name__)
+
+
+class Item(object):
+ """Description of an item of the plot"""
+
+ _DEFAULT_Z_LAYER = 0
+ """Default layer for overlay rendering"""
+
+ _DEFAULT_LEGEND = ''
+ """Default legend of items"""
+
+ _DEFAULT_SELECTABLE = False
+ """Default selectable state of items"""
+
+ def __init__(self):
+ self._dirty = True
+ self._plotRef = None
+ self._visible = True
+ self._legend = self._DEFAULT_LEGEND
+ self._selectable = self._DEFAULT_SELECTABLE
+ self._z = self._DEFAULT_Z_LAYER
+ self._info = None
+ self._xlabel = None
+ self._ylabel = None
+
+ self._backendRenderer = None
+
+ def getPlot(self):
+ """Returns Plot this item belongs to.
+
+ :rtype: Plot or None
+ """
+ return None if self._plotRef is None else self._plotRef()
+
+ def _setPlot(self, plot):
+ """Set the plot this item belongs to.
+
+ WARNING: This should only be called from the Plot.
+
+ :param Plot plot: The Plot instance.
+ """
+ if plot is not None and self._plotRef is not None:
+ raise RuntimeError('Trying to add a node at two places.')
+ self._plotRef = None if plot is None else weakref.ref(plot)
+ self._updated()
+
+ def getBounds(self): # TODO return a Bounds object rather than a tuple
+ """Returns the bounding box of this item in data coordinates
+
+ :returns: (xmin, xmax, ymin, ymax) or None
+ :rtype: 4-tuple of float or None
+ """
+ return self._getBounds()
+
+ def _getBounds(self):
+ """:meth:`getBounds` implementation to override by sub-class"""
+ return None
+
+ def isVisible(self):
+ """True if item is visible, False otherwise
+
+ :rtype: bool
+ """
+ return self._visible
+
+ def setVisible(self, visible):
+ """Set visibility of item.
+
+ :param bool visible: True to display it, False otherwise
+ """
+ visible = bool(visible)
+ if visible != self._visible:
+ self._visible = visible
+ # When visibility has changed, always mark as dirty
+ self._updated(checkVisibility=False)
+
+ def isOverlay(self):
+ """Return true if item is drawn as an overlay.
+
+ :rtype: bool
+ """
+ return False
+
+ def getLegend(self):
+ """Returns the legend of this item (str)"""
+ return self._legend
+
+ def _setLegend(self, legend):
+ """Set the legend.
+
+ This is private as it is used by the plot as an identifier
+
+ :param str legend: Item legend
+ """
+ legend = str(legend) if legend is not None else self._DEFAULT_LEGEND
+ self._legend = legend
+
+ def isSelectable(self):
+ """Returns true if item is selectable (bool)"""
+ return self._selectable
+
+ def _setSelectable(self, selectable): # TODO support update
+ """Set whether item is selectable or not.
+
+ This is private for now as change is not handled.
+
+ :param bool selectable: True to make item selectable
+ """
+ self._selectable = bool(selectable)
+
+ def getZValue(self):
+ """Returns the layer on which to draw this item (int)"""
+ return self._z
+
+ def setZValue(self, z):
+ z = int(z) if z is not None else self._DEFAULT_Z_LAYER
+ if z != self._z:
+ self._z = z
+ self._updated()
+
+ def getInfo(self, copy=True):
+ """Returns the info associated to this item
+
+ :param bool copy: True to get a deepcopy, False otherwise.
+ """
+ return deepcopy(self._info) if copy else self._info
+
+ def setInfo(self, info, copy=True):
+ if copy:
+ info = deepcopy(info)
+ self._info = info
+
+ def _updated(self, checkVisibility=True):
+ """Mark the item as dirty (i.e., needing update).
+
+ This also triggers Plot.replot.
+
+ :param bool checkVisibility: True to only mark as dirty if visible,
+ False to always mark as dirty.
+ """
+ if not checkVisibility or self.isVisible():
+ if not self._dirty:
+ self._dirty = True
+ # TODO: send event instead of explicit call
+ plot = self.getPlot()
+ if plot is not None:
+ plot._itemRequiresUpdate(self)
+
+ def _update(self, backend):
+ """Called by Plot to update the backend for this item.
+
+ This is meant to be called asynchronously from _updated.
+ This optimizes the number of call to _update.
+
+ :param backend: The backend to update
+ """
+ if self._dirty:
+ # Remove previous renderer from backend if any
+ self._removeBackendRenderer(backend)
+
+ # If not visible, do not add renderer to backend
+ if self.isVisible():
+ self._backendRenderer = self._addBackendRenderer(backend)
+
+ self._dirty = False
+
+ def _addBackendRenderer(self, backend):
+ """Override in subclass to add specific backend renderer.
+
+ :param BackendBase backend: The backend to update
+ :return: The renderer handle to store or None if no renderer in backend
+ """
+ return None
+
+ def _removeBackendRenderer(self, backend):
+ """Override in subclass to remove specific backend renderer.
+
+ :param BackendBase backend: The backend to update
+ """
+ if self._backendRenderer is not None:
+ backend.remove(self._backendRenderer)
+ self._backendRenderer = None
+
+
+# Mix-in classes ##############################################################
+
+class LabelsMixIn(object):
+ """Mix-in class for items with x and y labels
+
+ Setters are private, otherwise it needs to check the plot
+ current active curve and access the internal current labels.
+ """
+
+ def __init__(self):
+ self._xlabel = None
+ self._ylabel = None
+
+ def getXLabel(self):
+ """Return the X axis label associated to this curve
+
+ :rtype: str or None
+ """
+ return self._xlabel
+
+ def _setXLabel(self, label):
+ """Set the X axis label associated with this curve
+
+ :param str label: The X axis label
+ """
+ self._xlabel = str(label)
+
+ def getYLabel(self):
+ """Return the Y axis label associated to this curve
+
+ :rtype: str or None
+ """
+ return self._ylabel
+
+ def _setYLabel(self, label):
+ """Set the Y axis label associated with this curve
+
+ :param str label: The Y axis label
+ """
+ self._ylabel = str(label)
+
+
+class DraggableMixIn(object):
+ """Mix-in class for draggable items"""
+
+ def __init__(self):
+ self._draggable = False
+
+ def isDraggable(self):
+ """Returns true if image is draggable
+
+ :rtype: bool
+ """
+ return self._draggable
+
+ def _setDraggable(self, draggable): # TODO support update
+ """Set if image is draggable or not.
+
+ This is private for not as it does not support update.
+
+ :param bool draggable:
+ """
+ self._draggable = bool(draggable)
+
+
+class ColormapMixIn(object):
+ """Mix-in class for items with colormap"""
+
+ _DEFAULT_COLORMAP = {'name': 'gray', 'normalization': 'linear',
+ 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0}
+ """Default colormap of the item"""
+
+ def __init__(self):
+ self._colormap = self._DEFAULT_COLORMAP
+
+ def getColormap(self):
+ """Return the used colormap"""
+ return self._colormap.copy()
+
+ def setColormap(self, colormap):
+ """Set the colormap of this image
+
+ :param dict colormap: colormap description
+ """
+ self._colormap = colormap.copy()
+ # TODO colormap comparison + colormap object and events on modification
+ self._updated()
+
+
+class SymbolMixIn(object):
+ """Mix-in class for items with symbol type"""
+
+ _DEFAULT_SYMBOL = ''
+ """Default marker of the item"""
+
+ _DEFAULT_SYMBOL_SIZE = 6.0
+ """Default marker size of the item"""
+
+ def __init__(self):
+ self._symbol = self._DEFAULT_SYMBOL
+ self._symbol_size = self._DEFAULT_SYMBOL_SIZE
+
+ def getSymbol(self):
+ """Return the point marker type.
+
+ Marker type::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :rtype: str
+ """
+ return self._symbol
+
+ def setSymbol(self, symbol):
+ """Set the marker type
+
+ See :meth:`getSymbol`.
+
+ :param str symbol: Marker type
+ """
+ assert symbol in ('o', '.', ',', '+', 'x', 'd', 's', '', None)
+ if symbol is None:
+ symbol = self._DEFAULT_SYMBOL
+ if symbol != self._symbol:
+ self._symbol = symbol
+ self._updated()
+
+ def getSymbolSize(self):
+ """Return the point marker size in points.
+
+ :rtype: float
+ """
+ return self._symbol_size
+
+ def setSymbolSize(self, size):
+ """Set the point marker size in points.
+
+ See :meth:`getSymbolSize`.
+
+ :param str symbol: Marker type
+ """
+ if size is None:
+ size = self._DEFAULT_SYMBOL_SIZE
+ if size != self._symbol_size:
+ self._symbol_size = size
+ self._updated()
+
+
+class LineMixIn(object):
+ """Mix-in class for item with line"""
+
+ _DEFAULT_LINEWIDTH = 1.
+ """Default line width"""
+
+ _DEFAULT_LINESTYLE = '-'
+ """Default line style"""
+
+ def __init__(self):
+ self._linewidth = self._DEFAULT_LINEWIDTH
+ self._linestyle = self._DEFAULT_LINESTYLE
+
+ def getLineWidth(self):
+ """Return the curve line width in pixels (int)"""
+ return self._linewidth
+
+ def setLineWidth(self, width):
+ """Set the width in pixel of the curve line
+
+ See :meth:`getLineWidth`.
+
+ :param float width: Width in pixels
+ """
+ width = float(width)
+ if width != self._linewidth:
+ self._linewidth = width
+ self._updated()
+
+ def getLineStyle(self):
+ """Return the type of the line
+
+ Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+
+ :rtype: str
+ """
+ return self._linestyle
+
+ def setLineStyle(self, style):
+ """Set the style of the curve line.
+
+ See :meth:`getLineStyle`.
+
+ :param str style: Line style
+ """
+ style = str(style)
+ assert style in ('', ' ', '-', '--', '-.', ':', None)
+ if style is None:
+ style = self._DEFAULT_LINESTYLE
+ if style != self._linestyle:
+ self._linestyle = style
+ self._updated()
+
+
+class ColorMixIn(object):
+ """Mix-in class for item with color"""
+
+ _DEFAULT_COLOR = (0., 0., 0., 1.)
+ """Default color of the item"""
+
+ def __init__(self):
+ self._color = self._DEFAULT_COLOR
+
+ def getColor(self):
+ """Returns the RGBA color of the item
+
+ :rtype: 4-tuple of float in [0, 1]
+ """
+ return self._color
+
+ def setColor(self, color, copy=True):
+ """Set item color
+
+ :param color: color(s) to be used
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in Colors.py
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ if isinstance(color, six.string_types):
+ color = Colors.rgba(color)
+ else:
+ color = numpy.array(color, copy=copy)
+ # TODO more checks + improve color array support
+ if color.ndim == 1: # Single RGBA color
+ color = Colors.rgba(color)
+ else: # Array of colors
+ assert color.ndim == 2
+
+ self._color = color
+ self._updated()
+
+
+class YAxisMixIn(object):
+ """Mix-in class for item with yaxis"""
+
+ _DEFAULT_YAXIS = 'left'
+ """Default Y axis the item belongs to"""
+
+ def __init__(self):
+ self._yaxis = self._DEFAULT_YAXIS
+
+ def getYAxis(self):
+ """Returns the Y axis this curve belongs to.
+
+ Either 'left' or 'right'.
+
+ :rtype: str
+ """
+ return self._yaxis
+
+ def setYAxis(self, yaxis):
+ """Set the Y axis this curve belongs to.
+
+ :param str yaxis: 'left' or 'right'
+ """
+ yaxis = str(yaxis)
+ assert yaxis in ('left', 'right')
+ if yaxis != self._yaxis:
+ self._yaxis = yaxis
+ self._updated()
+
+
+class FillMixIn(object):
+ """Mix-in class for item with fill"""
+
+ def __init__(self):
+ self._fill = False
+
+ def isFill(self):
+ """Returns whether the item is filled or not.
+
+ :rtype: bool
+ """
+ return self._fill
+
+ def setFill(self, fill):
+ """Set whether to fill the item or not.
+
+ :param bool fill:
+ """
+ fill = bool(fill)
+ if fill != self._fill:
+ self._fill = fill
+ self._updated()
+
+
+class AlphaMixIn(object):
+ """Mix-in class for item with opacity"""
+
+ def __init__(self):
+ self._alpha = 1.
+
+ def getAlpha(self):
+ """Returns the opacity of the item
+
+ :rtype: float in [0, 1.]
+ """
+ return self._alpha
+
+ def setAlpha(self, alpha):
+ """Set the opacity of the item
+
+ .. note::
+
+ If the colormap already has some transparency, this alpha
+ adds additional transparency. The alpha channel of the colormap
+ is multiplied by this value.
+
+ :param alpha: Opacity of the item, between 0 (full transparency)
+ and 1. (full opacity)
+ :type alpha: float
+ """
+ alpha = float(alpha)
+ alpha = max(0., min(alpha, 1.)) # Clip alpha to [0., 1.] range
+ if alpha != self._alpha:
+ self._alpha = alpha
+ self._updated()
+
+
+class Points(Item, SymbolMixIn, AlphaMixIn):
+ """Base class for :class:`Curve` and :class:`Scatter`"""
+ # note: _logFilterData must be overloaded if you overload
+ # getData to change its signature
+
+ _DEFAULT_Z_LAYER = 1
+ """Default overlay layer for points,
+ on top of images."""
+
+ def __init__(self):
+ Item.__init__(self)
+ SymbolMixIn.__init__(self)
+ AlphaMixIn.__init__(self)
+ self._x = ()
+ self._y = ()
+ self._xerror = None
+ self._yerror = None
+
+ # Store filtered data for x > 0 and/or y > 0
+ self._filteredCache = {}
+ self._clippedCache = {}
+
+ # Store bounds depending on axes filtering >0:
+ # key is (isXPositiveFilter, isYPositiveFilter)
+ self._boundsCache = {}
+
+ @staticmethod
+ def _logFilterError(value, error):
+ """Filter/convert error values if they go <= 0.
+
+ Replace error leading to negative values by nan
+
+ :param numpy.ndarray value: 1D array of values
+ :param numpy.ndarray error:
+ Array of errors: scalar, N, Nx1 or 2xN or None.
+ :return: Filtered error so error bars are never negative
+ """
+ if error is not None:
+ # Convert Nx1 to N
+ if error.ndim == 2 and error.shape[1] == 1 and len(value) != 1:
+ error = numpy.ravel(error)
+
+ # Supports error being scalar, N or 2xN array
+ errorClipped = (value - numpy.atleast_2d(error)[0]) <= 0
+
+ if numpy.any(errorClipped): # Need filtering
+
+ # expand errorbars to 2xN
+ if error.size == 1: # Scalar
+ error = numpy.full(
+ (2, len(value)), error, dtype=numpy.float)
+
+ elif error.ndim == 1: # N array
+ newError = numpy.empty((2, len(value)),
+ dtype=numpy.float)
+ newError[0, :] = error
+ newError[1, :] = error
+ error = newError
+
+ elif error.size == 2 * len(value): # 2xN array
+ error = numpy.array(
+ error, copy=True, dtype=numpy.float)
+
+ else:
+ _logger.error("Unhandled error array")
+ return error
+
+ error[0, errorClipped] = numpy.nan
+
+ return error
+
+ def _getClippingBoolArray(self, xPositive, yPositive):
+ """Compute a boolean array to filter out points with negative
+ coordinates on log axes.
+
+ :param bool xPositive: True to filter arrays according to X coords.
+ :param bool yPositive: True to filter arrays according to Y coords.
+ :rtype: boolean numpy.ndarray
+ """
+ assert xPositive or yPositive
+ if (xPositive, yPositive) not in self._clippedCache:
+ x = self.getXData(copy=False)
+ y = self.getYData(copy=False)
+ xclipped = (x <= 0) if xPositive else False
+ yclipped = (y <= 0) if yPositive else False
+ self._clippedCache[(xPositive, yPositive)] = \
+ numpy.logical_or(xclipped, yclipped)
+ return self._clippedCache[(xPositive, yPositive)]
+
+ def _logFilterData(self, xPositive, yPositive):
+ """Filter out values with x or y <= 0 on log axes
+
+ :param bool xPositive: True to filter arrays according to X coords.
+ :param bool yPositive: True to filter arrays according to Y coords.
+ :return: The filter arrays or unchanged object if filtering not needed
+ :rtype: (x, y, xerror, yerror)
+ """
+ x = self.getXData(copy=False)
+ y = self.getYData(copy=False)
+ xerror = self.getXErrorData(copy=False)
+ yerror = self.getYErrorData(copy=False)
+
+ if xPositive or yPositive:
+ clipped = self._getClippingBoolArray(xPositive, yPositive)
+
+ if numpy.any(clipped):
+ # copy to keep original array and convert to float
+ x = numpy.array(x, copy=True, dtype=numpy.float)
+ x[clipped] = numpy.nan
+ y = numpy.array(y, copy=True, dtype=numpy.float)
+ y[clipped] = numpy.nan
+
+ if xPositive and xerror is not None:
+ xerror = self._logFilterError(x, xerror)
+
+ if yPositive and yerror is not None:
+ yerror = self._logFilterError(y, yerror)
+
+ return x, y, xerror, yerror
+
+ def _getBounds(self):
+ if self.getXData(copy=False).size == 0: # Empty data
+ return None
+
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.isXAxisLogarithmic()
+ yPositive = plot.isYAxisLogarithmic()
+ else:
+ xPositive = False
+ yPositive = False
+
+ # TODO bounds do not take error bars into account
+ if (xPositive, yPositive) not in self._boundsCache:
+ # use the getData class method because instance method can be
+ # overloaded to return additional arrays
+ data = Points.getData(self, copy=False,
+ displayed=True)
+ if len(data) == 5:
+ # hack to avoid duplicating caching mechanism in Scatter
+ # (happens when cached data is used, caching done using
+ # Scatter._logFilterData)
+ x, y, xerror, yerror = data[0], data[1], data[3], data[4]
+ else:
+ x, y, xerror, yerror = data
+
+ self._boundsCache[(xPositive, yPositive)] = (
+ numpy.nanmin(x),
+ numpy.nanmax(x),
+ numpy.nanmin(y),
+ numpy.nanmax(y)
+ )
+ return self._boundsCache[(xPositive, yPositive)]
+
+ def _getCachedData(self):
+ """Return cached filtered data if applicable,
+ i.e. if any axis is in log scale.
+ Return None if caching is not applicable."""
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.isXAxisLogarithmic()
+ yPositive = plot.isYAxisLogarithmic()
+ if xPositive or yPositive:
+ # At least one axis has log scale, filter data
+ if (xPositive, yPositive) not in self._filteredCache:
+ self._filteredCache[(xPositive, yPositive)] = \
+ self._logFilterData(xPositive, yPositive)
+ return self._filteredCache[(xPositive, yPositive)]
+ return None
+
+ def getData(self, copy=True, displayed=False):
+ """Returns the x, y values of the curve points and xerror, yerror
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :param bool displayed: True to only get curve points that are displayed
+ in the plot. Default: False
+ Note: If plot has log scale, negative points
+ are not displayed.
+ :returns: (x, y, xerror, yerror)
+ :rtype: 4-tuple of numpy.ndarray
+ """
+ if displayed: # filter data according to plot state
+ cached_data = self._getCachedData()
+ if cached_data is not None:
+ return cached_data
+
+ return (self.getXData(copy),
+ self.getYData(copy),
+ self.getXErrorData(copy),
+ self.getYErrorData(copy))
+
+ def getXData(self, copy=True):
+ """Returns the x coordinates of the data points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._x, copy=copy)
+
+ def getYData(self, copy=True):
+ """Returns the y coordinates of the data points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._y, copy=copy)
+
+ def getXErrorData(self, copy=True):
+ """Returns the x error of the points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray or None
+ """
+ if self._xerror is None:
+ return None
+ else:
+ return numpy.array(self._xerror, copy=copy)
+
+ def getYErrorData(self, copy=True):
+ """Returns the y error of the points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray or None
+ """
+ if self._yerror is None:
+ return None
+ else:
+ return numpy.array(self._yerror, copy=copy)
+
+ def setData(self, x, y, xerror=None, yerror=None, copy=True):
+ """Set the data of the curve.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates.
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values.
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ x = numpy.array(x, copy=copy)
+ y = numpy.array(y, copy=copy)
+ assert len(x) == len(y)
+ assert x.ndim == y.ndim == 1
+
+ if xerror is not None:
+ xerror = numpy.array(xerror, copy=copy)
+ if yerror is not None:
+ yerror = numpy.array(yerror, copy=copy)
+ # TODO checks on xerror, yerror
+ self._x, self._y = x, y
+ self._xerror, self._yerror = xerror, yerror
+
+ self._boundsCache = {} # Reset cached bounds
+ self._filteredCache = {} # Reset cached filtered data
+ self._clippedCache = {} # Reset cached clipped bool array
+
+ self._updated()
+ # TODO hackish data range implementation
+ if self.isVisible():
+ plot = self.getPlot()
+ if plot is not None:
+ plot._invalidateDataRange()
diff --git a/silx/gui/plot/items/curve.py b/silx/gui/plot/items/curve.py
new file mode 100644
index 0000000..d25ae00
--- /dev/null
+++ b/silx/gui/plot/items/curve.py
@@ -0,0 +1,192 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Curve` item of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2017"
+
+
+import logging
+
+import numpy
+
+from .. import Colors
+from .core import (Points, LabelsMixIn, SymbolMixIn,
+ ColorMixIn, YAxisMixIn, FillMixIn, LineMixIn)
+
+
+_logger = logging.getLogger(__name__)
+
+
+class Curve(Points, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn, LineMixIn):
+ """Description of a curve"""
+
+ _DEFAULT_Z_LAYER = 1
+ """Default overlay layer for curves"""
+
+ _DEFAULT_SELECTABLE = True
+ """Default selectable state for curves"""
+
+ _DEFAULT_LINEWIDTH = 1.
+ """Default line width of the curve"""
+
+ _DEFAULT_LINESTYLE = '-'
+ """Default line style of the curve"""
+
+ _DEFAULT_HIGHLIGHT_COLOR = (0, 0, 0, 255)
+ """Default highlight color of the item"""
+
+ def __init__(self):
+ Points.__init__(self)
+ ColorMixIn.__init__(self)
+ YAxisMixIn.__init__(self)
+ FillMixIn.__init__(self)
+ LabelsMixIn.__init__(self)
+ LineMixIn.__init__(self)
+
+ self._highlightColor = self._DEFAULT_HIGHLIGHT_COLOR
+ self._highlighted = False
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ # Filter-out values <= 0
+ xFiltered, yFiltered, xerror, yerror = self.getData(
+ copy=False, displayed=True)
+
+ if len(xFiltered) == 0:
+ return None # No data to display, do not add renderer to backend
+
+ return backend.addCurve(xFiltered, yFiltered, self.getLegend(),
+ color=self.getCurrentColor(),
+ symbol=self.getSymbol(),
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ yaxis=self.getYAxis(),
+ xerror=xerror,
+ yerror=yerror,
+ z=self.getZValue(),
+ selectable=self.isSelectable(),
+ fill=self.isFill(),
+ alpha=self.getAlpha(),
+ symbolsize=self.getSymbolSize())
+
+ def __getitem__(self, item):
+ """Compatibility with PyMca and silx <= 0.4.0"""
+ if isinstance(item, slice):
+ return [self[index] for index in range(*item.indices(5))]
+ elif item == 0:
+ return self.getXData(copy=False)
+ elif item == 1:
+ return self.getYData(copy=False)
+ elif item == 2:
+ return self.getLegend()
+ elif item == 3:
+ info = self.getInfo(copy=False)
+ return {} if info is None else info
+ elif item == 4:
+ params = {
+ 'info': self.getInfo(),
+ 'color': self.getColor(),
+ 'symbol': self.getSymbol(),
+ 'linewidth': self.getLineWidth(),
+ 'linestyle': self.getLineStyle(),
+ 'xlabel': self.getXLabel(),
+ 'ylabel': self.getYLabel(),
+ 'yaxis': self.getYAxis(),
+ 'xerror': self.getXErrorData(copy=False),
+ 'yerror': self.getYErrorData(copy=False),
+ 'z': self.getZValue(),
+ 'selectable': self.isSelectable(),
+ 'fill': self.isFill()
+ }
+ return params
+ else:
+ raise IndexError("Index out of range: %s", str(item))
+
+ def setVisible(self, visible):
+ """Set visibility of item.
+
+ :param bool visible: True to display it, False otherwise
+ """
+ visibleChanged = self.isVisible() != bool(visible)
+ super(Curve, self).setVisible(visible)
+
+ # TODO hackish data range implementation
+ if visibleChanged:
+ plot = self.getPlot()
+ if plot is not None:
+ plot._invalidateDataRange()
+
+ def isHighlighted(self):
+ """Returns True if curve is highlighted.
+
+ :rtype: bool
+ """
+ return self._highlighted
+
+ def setHighlighted(self, highlighted):
+ """Set the highlight state of the curve
+
+ :param bool highlighted:
+ """
+ highlighted = bool(highlighted)
+ if highlighted != self._highlighted:
+ self._highlighted = highlighted
+ # TODO inefficient: better to use backend's setCurveColor
+ self._updated()
+
+ def getHighlightedColor(self):
+ """Returns the RGBA highlight color of the item
+
+ :rtype: 4-tuple of int in [0, 255]
+ """
+ return self._highlightColor
+
+ def setHighlightedColor(self, color):
+ """Set the color to use when highlighted
+
+ :param color: color(s) to be used for highlight
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in Colors.py
+ """
+ color = Colors.rgba(color)
+ if color != self._highlightColor:
+ self._highlightColor = color
+ self._updated()
+
+ def getCurrentColor(self):
+ """Returns the current color of the curve.
+
+ This color is either the color of the curve or the highlighted color,
+ depending on the highlight state.
+
+ :rtype: 4-tuple of int in [0, 255]
+ """
+ if self.isHighlighted():
+ return self.getHighlightedColor()
+ else:
+ return self.getColor()
diff --git a/silx/gui/plot/items/histogram.py b/silx/gui/plot/items/histogram.py
new file mode 100644
index 0000000..c3821bc
--- /dev/null
+++ b/silx/gui/plot/items/histogram.py
@@ -0,0 +1,288 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Histogram` item of the :class:`Plot`.
+"""
+
+__authors__ = ["H. Payno", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/05/2017"
+
+
+import logging
+
+import numpy
+
+from .core import (Item, AlphaMixIn, ColorMixIn, FillMixIn,
+ LineMixIn, YAxisMixIn)
+
+
+_logger = logging.getLogger(__name__)
+
+
+def _computeEdges(x, histogramType):
+ """Compute the edges from a set of xs and a rule to generate the edges
+
+ :param x: the x value of the curve to transform into an histogram
+ :param histogramType: the type of histogram we wan't to generate.
+ This define the way to center the histogram values compared to the
+ curve value. Possible values can be::
+
+ - 'left'
+ - 'right'
+ - 'center'
+
+ :return: the edges for the given x and the histogramType
+ """
+ # for now we consider that the spaces between xs are constant
+ edges = x.copy()
+ if histogramType is 'left':
+ width = 1
+ if len(x) > 1:
+ width = x[1] - x[0]
+ edges = numpy.append(x[0] - width, edges)
+ if histogramType is 'center':
+ edges = _computeEdges(edges, 'right')
+ widths = (edges[1:] - edges[0:-1]) / 2.0
+ widths = numpy.append(widths, widths[-1])
+ edges = edges - widths
+ if histogramType is 'right':
+ width = 1
+ if len(x) > 1:
+ width = x[-1] - x[-2]
+ edges = numpy.append(edges, x[-1] + width)
+
+ return edges
+
+
+def _getHistogramCurve(histogram, edges):
+ """Returns the x and y value of a curve corresponding to the histogram
+
+ :param numpy.ndarray histogram: The values of the histogram
+ :param numpy.ndarray edges: The bin edges of the histogram
+ :return: a tuple(x, y) which contains the value of the curve to use
+ to display the histogram
+ """
+ assert len(histogram) + 1 == len(edges)
+ x = numpy.empty(len(histogram) * 2, dtype=edges.dtype)
+ y = numpy.empty(len(histogram) * 2, dtype=histogram.dtype)
+ # Make a curve with stairs
+ x[:-1:2] = edges[:-1]
+ x[1::2] = edges[1:]
+ y[:-1:2] = histogram
+ y[1::2] = histogram
+
+ return x, y
+
+
+# TODO: Yerror, test log scale
+class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn,
+ LineMixIn, YAxisMixIn):
+ """Description of an histogram"""
+
+ _DEFAULT_Z_LAYER = 1
+ """Default overlay layer for histograms"""
+
+ _DEFAULT_SELECTABLE = False
+ """Default selectable state for histograms"""
+
+ _DEFAULT_LINEWIDTH = 1.
+ """Default line width of the histogram"""
+
+ _DEFAULT_LINESTYLE = '-'
+ """Default line style of the histogram"""
+
+ def __init__(self):
+ Item.__init__(self)
+ AlphaMixIn.__init__(self)
+ ColorMixIn.__init__(self)
+ FillMixIn.__init__(self)
+ LineMixIn.__init__(self)
+ YAxisMixIn.__init__(self)
+
+ self._histogram = ()
+ self._edges = ()
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ values, edges = self.getData(copy=False)
+
+ if values.size == 0:
+ return None # No data to display, do not add renderer
+
+ if values.size == 0:
+ return None # No data to display, do not add renderer to backend
+
+ x, y = _getHistogramCurve(values, edges)
+
+ # Filter-out values <= 0
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.isXAxisLogarithmic()
+ yPositive = plot.isYAxisLogarithmic()
+ else:
+ xPositive = False
+ yPositive = False
+
+ if xPositive or yPositive:
+ clipped = numpy.logical_or(
+ (x <= 0) if xPositive else False,
+ (y <= 0) if yPositive else False)
+ # Make a copy and replace negative points by NaN
+ x = numpy.array(x, dtype=numpy.float)
+ y = numpy.array(y, dtype=numpy.float)
+ x[clipped] = numpy.nan
+ y[clipped] = numpy.nan
+
+ return backend.addCurve(x, y, self.getLegend(),
+ color=self.getColor(),
+ symbol='',
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ yaxis=self.getYAxis(),
+ xerror=None,
+ yerror=None,
+ z=self.getZValue(),
+ selectable=self.isSelectable(),
+ fill=self.isFill(),
+ alpha=self.getAlpha(),
+ symbolsize=1)
+
+ def _getBounds(self):
+ values, edges = self.getData(copy=False)
+
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.isXAxisLogarithmic()
+ yPositive = plot.isYAxisLogarithmic()
+ else:
+ xPositive = False
+ yPositive = False
+
+ if xPositive or yPositive:
+ values = numpy.array(values, copy=True, dtype=numpy.float)
+
+ if xPositive:
+ # Replace edges <= 0 by NaN and corresponding values by NaN
+ clipped = (edges <= 0)
+ edges = numpy.array(edges, copy=True, dtype=numpy.float)
+ edges[clipped] = numpy.nan
+ values[numpy.logical_or(clipped[:-1], clipped[1:])] = numpy.nan
+
+ if yPositive:
+ # Replace values <= 0 by NaN, do not modify edges
+ values[values <= 0] = numpy.nan
+
+ if xPositive or yPositive:
+ return (numpy.nanmin(edges),
+ numpy.nanmax(edges),
+ numpy.nanmin(values),
+ numpy.nanmax(values))
+
+ else: # No log scale, include 0 in bounds
+ return (numpy.nanmin(edges),
+ numpy.nanmax(edges),
+ min(0, numpy.nanmin(values)),
+ max(0, numpy.nanmax(values)))
+
+ def setVisible(self, visible):
+ """Set visibility of item.
+
+ :param bool visible: True to display it, False otherwise
+ """
+ visibleChanged = self.isVisible() != bool(visible)
+ super(Histogram, self).setVisible(visible)
+
+ # TODO hackish data range implementation
+ if visibleChanged:
+ plot = self.getPlot()
+ if plot is not None:
+ plot._invalidateDataRange()
+
+ def getValueData(self, copy=True):
+ """The values of the histogram
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :returns: The bin edges of the histogram
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._histogram, copy=copy)
+
+ def getBinEdgesData(self, copy=True):
+ """The bin edges of the histogram (number of histogram values + 1)
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :returns: The bin edges of the histogram
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._edges, copy=copy)
+
+ def getData(self, copy=True):
+ """Return the histogram values and the bin edges
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :returns: (N histogram value, N+1 bin edges)
+ :rtype: 2-tuple of numpy.nadarray
+ """
+ return (self.getValueData(copy), self.getBinEdgesData(copy))
+
+ def setData(self, histogram, edges, align='center', copy=True):
+ """Set the histogram values and bin edges.
+
+ :param numpy.ndarray histogram: The values of the histogram.
+ :param numpy.ndarray edges:
+ The bin edges of the histogram.
+ If histogram and edges have the same length, the bin edges
+ are computed according to the align parameter.
+ :param str align:
+ In case histogram values and edges have the same length N,
+ the N+1 bin edges are computed according to the alignment in:
+ 'center' (default), 'left', 'right'.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ histogram = numpy.array(histogram, copy=copy)
+ edges = numpy.array(edges, copy=copy)
+
+ assert histogram.ndim == 1
+ assert edges.ndim == 1
+ assert edges.size in (histogram.size, histogram.size + 1)
+ assert align in ('center', 'left', 'right')
+
+ if histogram.size == 0: # No data
+ self._histogram = ()
+ self._edges = ()
+ else:
+ if edges.size == histogram.size: # Compute true bin edges
+ edges = _computeEdges(edges, align)
+
+ # Check that bin edges are monotonic
+ edgesDiff = numpy.diff(edges)
+ assert numpy.all(edgesDiff >= 0) or numpy.all(edgesDiff <= 0)
+
+ self._histogram = histogram
+ self._edges = edges
diff --git a/silx/gui/plot/items/image.py b/silx/gui/plot/items/image.py
new file mode 100644
index 0000000..7e1dd8b
--- /dev/null
+++ b/silx/gui/plot/items/image.py
@@ -0,0 +1,385 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`ImageData` and :class:`ImageRgba` items
+of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2017"
+
+
+from collections import Sequence
+import logging
+
+import numpy
+
+from .core import Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, AlphaMixIn
+from ..Colors import applyColormapToData
+
+
+_logger = logging.getLogger(__name__)
+
+
+def _convertImageToRgba32(image, copy=True):
+ """Convert an RGB or RGBA image to RGBA32.
+
+ It converts from floats in [0, 1], bool, integer and uint in [0, 255]
+
+ If the input image is already an RGBA32 image,
+ the returned image shares the same data.
+
+ :param image: Image to convert to
+ :type image: numpy.ndarray with 3 dimensions: height, width, color channels
+ :param bool copy: True (Default) to get a copy, False, avoid copy if possible
+ :return: The image converted to RGBA32 with dimension: (height, width, 4)
+ :rtype: numpy.ndarray of uint8
+ """
+ assert image.ndim == 3
+ assert image.shape[-1] in (3, 4)
+
+ # Convert type to uint8
+ if image.dtype.name != 'uin8':
+ if image.dtype.kind == 'f': # Float in [0, 1]
+ image = (numpy.clip(image, 0., 1.) * 255).astype(numpy.uint8)
+ elif image.dtype.kind == 'b': # boolean
+ image = image.astype(numpy.uint8) * 255
+ elif image.dtype.kind in ('i', 'u'): # int, uint
+ image = numpy.clip(image, 0, 255).astype(numpy.uint8)
+ else:
+ raise ValueError('Unsupported image dtype: %s', image.dtype.name)
+ copy = False # A copy as already been done, avoid next one
+
+ # Convert RGB to RGBA
+ if image.shape[-1] == 3:
+ new_image = numpy.empty((image.shape[0], image.shape[1], 4),
+ dtype=numpy.uint8)
+ new_image[:, :, :3] = image
+ new_image[:, :, 3] = 255
+ return new_image # This is a copy anyway
+ else:
+ return numpy.array(image, copy=copy)
+
+
+class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn):
+ """Description of an image"""
+
+ def __init__(self):
+ Item.__init__(self)
+ LabelsMixIn.__init__(self)
+ DraggableMixIn.__init__(self)
+ AlphaMixIn.__init__(self)
+ self._data = numpy.zeros((0, 0, 4), dtype=numpy.uint8)
+
+ self._origin = (0., 0.)
+ self._scale = (1., 1.)
+
+ def __getitem__(self, item):
+ """Compatibility with PyMca and silx <= 0.4.0"""
+ if isinstance(item, slice):
+ return [self[index] for index in range(*item.indices(5))]
+ elif item == 0:
+ return self.getData(copy=False)
+ elif item == 1:
+ return self.getLegend()
+ elif item == 2:
+ info = self.getInfo(copy=False)
+ return {} if info is None else info
+ elif item == 3:
+ return None
+ elif item == 4:
+ params = {
+ 'info': self.getInfo(),
+ 'origin': self.getOrigin(),
+ 'scale': self.getScale(),
+ 'z': self.getZValue(),
+ 'selectable': self.isSelectable(),
+ 'draggable': self.isDraggable(),
+ 'colormap': None,
+ 'xlabel': self.getXLabel(),
+ 'ylabel': self.getYLabel(),
+ }
+ return params
+ else:
+ raise IndexError("Index out of range: %s" % str(item))
+
+ def setVisible(self, visible):
+ """Set visibility of item.
+
+ :param bool visible: True to display it, False otherwise
+ """
+ visibleChanged = self.isVisible() != bool(visible)
+ super(ImageBase, self).setVisible(visible)
+
+ # TODO hackish data range implementation
+ if visibleChanged:
+ plot = self.getPlot()
+ if plot is not None:
+ plot._invalidateDataRange()
+
+ def _getBounds(self):
+ if self.getData(copy=False).size == 0: # Empty data
+ return None
+
+ height, width = self.getData(copy=False).shape[:2]
+ origin = self.getOrigin()
+ scale = self.getScale()
+ # Taking care of scale might be < 0
+ xmin, xmax = origin[0], origin[0] + width * scale[0]
+ if xmin > xmax:
+ xmin, xmax = xmax, xmin
+ # Taking care of scale might be < 0
+ ymin, ymax = origin[1], origin[1] + height * scale[1]
+ if ymin > ymax:
+ ymin, ymax = ymax, ymin
+
+ plot = self.getPlot()
+ if (plot is not None and
+ plot.isXAxisLogarithmic() or plot.isYAxisLogarithmic()):
+ return None
+ else:
+ return xmin, xmax, ymin, ymax
+
+ def getData(self, copy=True):
+ """Returns the image data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._data, copy=copy)
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :returns: numpy.ndarray of uint8 of shape (height, width, 4)
+ """
+ raise NotImplementedError('This MUST be implemented in sub-class')
+
+ def getOrigin(self):
+ """Returns the offset from origin at which to display the image.
+
+ :rtype: 2-tuple of float
+ """
+ return self._origin
+
+ def setOrigin(self, origin):
+ """Set the offset from origin at which to display the image.
+
+ :param origin: (ox, oy) Offset from origin
+ :type origin: float or 2-tuple of float
+ """
+ if isinstance(origin, Sequence):
+ origin = float(origin[0]), float(origin[1])
+ else: # single value origin
+ origin = float(origin), float(origin)
+ if origin != self._origin:
+ self._origin = origin
+ self._updated()
+
+ # TODO hackish data range implementation
+ if self.isVisible():
+ plot = self.getPlot()
+ if plot is not None:
+ plot._invalidateDataRange()
+
+ def getScale(self):
+ """Returns the scale of the image in data coordinates.
+
+ :rtype: 2-tuple of float
+ """
+ return self._scale
+
+ def setScale(self, scale):
+ """Set the scale of the image
+
+ :param scale: (sx, sy) Scale of the image
+ :type scale: float or 2-tuple of float
+ """
+ if isinstance(scale, Sequence):
+ scale = float(scale[0]), float(scale[1])
+ else: # single value scale
+ scale = float(scale), float(scale)
+ if scale != self._scale:
+ self._scale = scale
+ self._updated()
+
+
+class ImageData(ImageBase, ColormapMixIn):
+ """Description of a data image with a colormap"""
+
+ def __init__(self):
+ ImageBase.__init__(self)
+ ColormapMixIn.__init__(self)
+ self._data = numpy.zeros((0, 0), dtype=numpy.float32)
+ self._alternativeImage = None
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ plot = self.getPlot()
+ assert plot is not None
+ if plot.isXAxisLogarithmic() or plot.isYAxisLogarithmic():
+ return None # Do not render with log scales
+
+ if self.getAlternativeImageData(copy=False) is not None:
+ dataToUse = self.getAlternativeImageData(copy=False)
+ else:
+ dataToUse = self.getData(copy=False)
+
+ if dataToUse.size == 0:
+ return None # No data to display
+
+ return backend.addImage(dataToUse,
+ legend=self.getLegend(),
+ origin=self.getOrigin(),
+ scale=self.getScale(),
+ z=self.getZValue(),
+ selectable=self.isSelectable(),
+ draggable=self.isDraggable(),
+ colormap=self.getColormap(),
+ alpha=self.getAlpha())
+
+ def __getitem__(self, item):
+ """Compatibility with PyMca and silx <= 0.4.0"""
+ if item == 3:
+ return self.getAlternativeImageData(copy=False)
+
+ params = ImageBase.__getitem__(self, item)
+ if item == 4:
+ params['colormap'] = self.getColormap()
+
+ return params
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :returns: numpy.ndarray of uint8 of shape (height, width, 4)
+ """
+ if self._alternativeImage is not None:
+ return _convertImageToRgba32(
+ self.getAlternativeImageData(copy=False), copy=copy)
+ else:
+ # Apply colormap, in this case an new array is always returned
+ colormap = self.getColormap()
+ image = applyColormapToData(self.getData(copy=False),
+ **colormap)
+ return image
+
+ def getAlternativeImageData(self, copy=True):
+ """Get the optional RGBA image that is displayed instead of the data
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :returns: None or numpy.ndarray
+ :rtype: numpy.ndarray or None
+ """
+ if self._alternativeImage is None:
+ return None
+ else:
+ return numpy.array(self._alternativeImage, copy=copy)
+
+ def setData(self, data, alternative=None, copy=True):
+ """"Set the image data and optionally an alternative RGB(A) representation
+
+ :param numpy.ndarray data: Data array with 2 dimensions (h, w)
+ :param alternative: RGB(A) image to display instead of data,
+ shape: (h, w, 3 or 4)
+ :type alternative: None or numpy.ndarray
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+ self._data = data
+
+ if alternative is not None:
+ alternative = numpy.array(alternative, copy=copy)
+ assert alternative.ndim == 3
+ assert alternative.shape[2] in (3, 4)
+ assert alternative.shape[:2] == data.shape[:2]
+ self._alternativeImage = alternative
+ self._updated()
+
+ # TODO hackish data range implementation
+ if self.isVisible():
+ plot = self.getPlot()
+ if plot is not None:
+ plot._invalidateDataRange()
+
+
+class ImageRgba(ImageBase):
+ """Description of an RGB(A) image"""
+
+ def __init__(self):
+ ImageBase.__init__(self)
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ plot = self.getPlot()
+ assert plot is not None
+ if plot.isXAxisLogarithmic() or plot.isYAxisLogarithmic():
+ return None # Do not render with log scales
+
+ data = self.getData(copy=False)
+
+ if data.size == 0:
+ return None # No data to display
+
+ return backend.addImage(data,
+ legend=self.getLegend(),
+ origin=self.getOrigin(),
+ scale=self.getScale(),
+ z=self.getZValue(),
+ selectable=self.isSelectable(),
+ draggable=self.isDraggable(),
+ colormap=None,
+ alpha=self.getAlpha())
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :returns: numpy.ndarray of uint8 of shape (height, width, 4)
+ """
+ return _convertImageToRgba32(self.getData(copy=False), copy=copy)
+
+ def setData(self, data, copy=True):
+ """Set the image data
+
+ :param data: RGB(A) image data to set
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 3
+ assert data.shape[-1] in (3, 4)
+ self._data = data
+
+ self._updated()
+
+ # TODO hackish data range implementation
+ if self.isVisible():
+ plot = self.getPlot()
+ if plot is not None:
+ plot._invalidateDataRange()
diff --git a/silx/gui/plot/items/marker.py b/silx/gui/plot/items/marker.py
new file mode 100644
index 0000000..c05558b
--- /dev/null
+++ b/silx/gui/plot/items/marker.py
@@ -0,0 +1,241 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides markers item of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2017"
+
+
+import logging
+
+from .core import Item, DraggableMixIn, ColorMixIn, SymbolMixIn
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _BaseMarker(Item, DraggableMixIn, ColorMixIn):
+ """Base class for markers"""
+
+ _DEFAULT_COLOR = (0., 0., 0., 1.)
+ """Default color of the markers"""
+
+ def __init__(self):
+ Item.__init__(self)
+ DraggableMixIn.__init__(self)
+ ColorMixIn.__init__(self)
+
+ self._text = ''
+ self._x = None
+ self._y = None
+ self._constraint = self._defaultConstraint
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ # TODO not very nice way to do it, but simple
+ symbol = self.getSymbol() if isinstance(self, Marker) else None
+
+ return backend.addMarker(
+ x=self.getXPosition(),
+ y=self.getYPosition(),
+ legend=self.getLegend(),
+ text=self.getText(),
+ color=self.getColor(),
+ selectable=self.isSelectable(),
+ draggable=self.isDraggable(),
+ symbol=symbol,
+ constraint=self.getConstraint(),
+ overlay=self.isOverlay())
+
+ def isOverlay(self):
+ """Return true if marker is drawn as an overlay.
+
+ A marker is an overlay if it is draggable.
+
+ :rtype: bool
+ """
+ return self.isDraggable()
+
+ def getText(self):
+ """Returns marker text.
+
+ :rtype: str
+ """
+ return self._text
+
+ def setText(self, text):
+ """Set the text of the marker.
+
+ :param str text: The text to use
+ """
+ text = str(text)
+ if text != self._text:
+ self._text = text
+ self._updated()
+
+ def getXPosition(self):
+ """Returns the X position of the marker line in data coordinates
+
+ :rtype: float or None
+ """
+ return self._x
+
+ def getYPosition(self):
+ """Returns the Y position of the marker line in data coordinates
+
+ :rtype: float or None
+ """
+ return self._y
+
+ def getPosition(self):
+ """Returns the (x, y) position of the marker in data coordinates
+
+ :rtype: 2-tuple of float or None
+ """
+ return self._x, self._y
+
+ def setPosition(self, x, y):
+ """Set marker position in data coordinates
+
+ Constraint are applied if any.
+
+ :param float x: X coordinates in data frame
+ :param float y: Y coordinates in data frame
+ """
+ x, y = self.getConstraint()(x, y)
+ x, y = float(x), float(y)
+ if x != self._x or y != self._y:
+ self._x, self._y = x, y
+ self._updated()
+
+ def getConstraint(self):
+ """Returns the dragging constraint of this item"""
+ return self._constraint
+
+ def _setConstraint(self, constraint): # TODO support update
+ """Set the constraint.
+
+ This is private for now as update is not handled.
+
+ :param callable constraint:
+ :param constraint: A function filtering item displacement by
+ dragging operations or None for no filter.
+ This function is called each time the item is
+ moved.
+ This is only used if isDraggable returns True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ """
+ if constraint is None:
+ constraint = self._defaultConstraint
+ assert callable(constraint)
+ self._constraint = constraint
+
+ @staticmethod
+ def _defaultConstraint(*args):
+ """Default constraint not doing anything"""
+ return args
+
+
+class Marker(_BaseMarker, SymbolMixIn):
+ """Description of a marker"""
+
+ _DEFAULT_SYMBOL = '+'
+ """Default symbol of the marker"""
+
+ def __init__(self):
+ _BaseMarker.__init__(self)
+ SymbolMixIn.__init__(self)
+
+ self._x = 0.
+ self._y = 0.
+
+ def _setConstraint(self, constraint):
+ """Set the constraint function of the marker drag.
+
+ It also supports 'horizontal' and 'vertical' str as constraint.
+
+ :param constraint: The constraint of the dragging of this marker
+ :type: constraint: callable or str
+ """
+ if constraint == 'horizontal':
+ constraint = self._horizontalConstraint
+ elif constraint == 'vertical':
+ constraint = self._verticalConstraint
+
+ super(Marker, self)._setConstraint(constraint)
+
+ def _horizontalConstraint(self, _, y):
+ return self.getXPosition(), y
+
+ def _verticalConstraint(self, x, _):
+ return x, self.getYPosition()
+
+
+class XMarker(_BaseMarker):
+ """Description of a marker"""
+
+ def __init__(self):
+ _BaseMarker.__init__(self)
+ self._x = 0.
+
+ def setPosition(self, x, y):
+ """Set marker line position in data coordinates
+
+ Constraint are applied if any.
+
+ :param float x: X coordinates in data frame
+ :param float y: Y coordinates in data frame
+ """
+ x, _ = self.getConstraint()(x, y)
+ x = float(x)
+ if x != self._x:
+ self._x = x
+ self._updated()
+
+
+class YMarker(_BaseMarker):
+ """Description of a marker"""
+
+ def __init__(self):
+ _BaseMarker.__init__(self)
+ self._y = 0.
+
+ def setPosition(self, x, y):
+ """Set marker line position in data coordinates
+
+ Constraint are applied if any.
+
+ :param float x: X coordinates in data frame
+ :param float y: Y coordinates in data frame
+ """
+ _, y = self.getConstraint()(x, y)
+ y = float(y)
+ if y != self._y:
+ self._y = y
+ self._updated()
diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py
new file mode 100644
index 0000000..3897dc1
--- /dev/null
+++ b/silx/gui/plot/items/scatter.py
@@ -0,0 +1,169 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Scatter` item of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "29/03/2017"
+
+
+import logging
+
+import numpy
+
+from .core import Points, ColormapMixIn
+from silx.gui.plot.Colors import applyColormapToData # TODO: cherry-pick commit or wait for PR merge
+
+_logger = logging.getLogger(__name__)
+
+
+class Scatter(Points, ColormapMixIn):
+ """Description of a scatter"""
+ _DEFAULT_SYMBOL = 'o'
+ """Default symbol of the scatter plots"""
+
+ def __init__(self):
+ Points.__init__(self)
+ ColormapMixIn.__init__(self)
+ self._value = ()
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ # Filter-out values <= 0
+ xFiltered, yFiltered, valueFiltered, xerror, yerror = self.getData(
+ copy=False, displayed=True)
+
+ if len(xFiltered) == 0:
+ return None # No data to display, do not add renderer to backend
+
+ cmap = self.getColormap()
+ rgbacolors = applyColormapToData(self._value,
+ cmap["name"],
+ cmap["normalization"],
+ cmap["autoscale"],
+ cmap["vmin"],
+ cmap["vmax"],
+ cmap.get("colors"))
+
+ return backend.addCurve(xFiltered, yFiltered, self.getLegend(),
+ color=rgbacolors,
+ symbol=self.getSymbol(),
+ linewidth=0,
+ linestyle="",
+ yaxis='left',
+ xerror=xerror,
+ yerror=yerror,
+ z=self.getZValue(),
+ selectable=self.isSelectable(),
+ fill=False,
+ alpha=self.getAlpha(),
+ symbolsize=self.getSymbolSize())
+
+ def _logFilterData(self, xPositive, yPositive):
+ """Filter out values with x or y <= 0 on log axes
+
+ :param bool xPositive: True to filter arrays according to X coords.
+ :param bool yPositive: True to filter arrays according to Y coords.
+ :return: The filtered arrays or unchanged object if not filtering needed
+ :rtype: (x, y, value, xerror, yerror)
+ """
+ # overloaded from Points to filter also value.
+ value = self.getValueData(copy=False)
+
+ if xPositive or yPositive:
+ clipped = self._getClippingBoolArray(xPositive, yPositive)
+
+ if numpy.any(clipped):
+ # copy to keep original array and convert to float
+ value = numpy.array(value, copy=True, dtype=numpy.float)
+ value[clipped] = numpy.nan
+
+ x, y, xerror, yerror = Points._logFilterData(self, xPositive, yPositive)
+
+ return x, y, value, xerror, yerror
+
+ def getValueData(self, copy=True):
+ """Returns the value assigned to the scatter data points.
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._value, copy=copy)
+
+ def getData(self, copy=True, displayed=False):
+ """Returns the x, y coordinates and the value of the data points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :param bool displayed: True to only get curve points that are displayed
+ in the plot. Default: False.
+ Note: If plot has log scale, negative points
+ are not displayed.
+ :returns: (x, y, value, xerror, yerror)
+ :rtype: 5-tuple of numpy.ndarray
+ """
+ if displayed:
+ data = self._getCachedData()
+ if data is not None:
+ assert len(data) == 5
+ return data
+
+ return (self.getXData(copy),
+ self.getYData(copy),
+ self.getValueData(copy),
+ self.getXErrorData(copy),
+ self.getYErrorData(copy))
+
+ # reimplemented from Points to handle `value`
+ def setData(self, x, y, value, xerror=None, yerror=None, copy=True):
+ """Set the data of the scatter.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates.
+ :param numpy.ndarray value: The data corresponding to the value of
+ the data points.
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ value = numpy.array(value, copy=copy)
+ assert value.ndim == 1
+ assert len(x) == len(value)
+
+ self._value = value
+
+ # set x, y, xerror, yerror
+
+ # call self._updated + plot._invalidateDataRange()
+ Points.setData(self, x, y, xerror, yerror, copy)
diff --git a/silx/gui/plot/items/shape.py b/silx/gui/plot/items/shape.py
new file mode 100644
index 0000000..b663989
--- /dev/null
+++ b/silx/gui/plot/items/shape.py
@@ -0,0 +1,121 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Shape` item of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2017"
+
+
+import logging
+
+import numpy
+
+from .core import Item, ColorMixIn, FillMixIn
+
+
+_logger = logging.getLogger(__name__)
+
+
+# TODO probably make one class for each kind of shape
+# TODO check fill:polygon/polyline + fill = duplicated
+class Shape(Item, ColorMixIn, FillMixIn):
+ """Description of a shape item
+
+ :param str type_: The type of shape in:
+ 'hline', 'polygon', 'rectangle', 'vline', 'polyline'
+ """
+
+ def __init__(self, type_):
+ Item.__init__(self)
+ ColorMixIn.__init__(self)
+ FillMixIn.__init__(self)
+ self._overlay = False
+ assert type_ in ('hline', 'polygon', 'rectangle', 'vline', 'polyline')
+ self._type = type_
+ self._points = ()
+
+ self._handle = None
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ points = self.getPoints(copy=False)
+ x, y = points.T[0], points.T[1]
+ return backend.addItem(x,
+ y,
+ legend=self.getLegend(),
+ shape=self.getType(),
+ color=self.getColor(),
+ fill=self.isFill(),
+ overlay=self.isOverlay(),
+ z=self.getZValue())
+
+ def isOverlay(self):
+ """Return true if shape is drawn as an overlay
+
+ :rtype: bool
+ """
+ return self._overlay
+
+ def setOverlay(self, overlay):
+ """Set the overlay state of the shape
+
+ :param bool overlay: True to make it an overlay
+ """
+ overlay = bool(overlay)
+ if overlay != self._overlay:
+ self._overlay = overlay
+ self._updated()
+
+ def getType(self):
+ """Returns the type of shape to draw.
+
+ One of: 'hline', 'polygon', 'rectangle', 'vline', 'polyline'
+
+ :rtype: str
+ """
+ return self._type
+
+ def getPoints(self, copy=True):
+ """Get the control points of the shape.
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :return: Array of point coordinates
+ :rtype: numpy.ndarray with 2 dimensions
+ """
+ return numpy.array(self._points, copy=copy)
+
+ def setPoints(self, points, copy=True):
+ """Set the point coordinates
+
+ :param numpy.ndarray points: Array of point coordinates
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :return:
+ """
+ self._points = numpy.array(points, copy=copy)
+ self._updated()
diff --git a/silx/gui/plot/setup.py b/silx/gui/plot/setup.py
new file mode 100644
index 0000000..6408113
--- /dev/null
+++ b/silx/gui/plot/setup.py
@@ -0,0 +1,47 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/02/2016"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('plot', parent_package, top_path)
+ config.add_subpackage('_utils')
+ config.add_subpackage('backends')
+ config.add_subpackage('backends.glutils')
+ config.add_subpackage('items')
+ config.add_subpackage('test')
+
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/silx/gui/plot/test/__init__.py b/silx/gui/plot/test/__init__.py
new file mode 100644
index 0000000..b4378c7
--- /dev/null
+++ b/silx/gui/plot/test/__init__.py
@@ -0,0 +1,71 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/02/2016"
+
+
+import unittest
+
+from .._utils.test import suite as testUtilsSuite
+from .testColorBar import suite as testColorBarSuite
+from .testColormapDialog import suite as testColormapDialogSuite
+from .testColors import suite as testColorsSuite
+from .testCurvesROIWidget import suite as testCurvesROIWidgetSuite
+from .testAlphaSlider import suite as testAlphaSliderSuite
+from .testInteraction import suite as testInteractionSuite
+from .testLegendSelector import suite as testLegendSelectorSuite
+from .testMaskToolsWidget import suite as testMaskToolsWidgetSuite
+from .testScatterMaskToolsWidget import suite as testScatterMaskToolsWidgetSuite
+from .testPlotInteraction import suite as testPlotInteractionSuite
+from .testPlotTools import suite as testPlotToolsSuite
+from .testPlotWidget import suite as testPlotWidgetSuite
+from .testPlotWindow import suite as testPlotWindowSuite
+from .testPlot import suite as testPlotSuite
+from .testProfile import suite as testProfileSuite
+from .testStackView import suite as testStackViewSuite
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTests(
+ [testUtilsSuite(),
+ testColorBarSuite(),
+ testColorsSuite(),
+ testColormapDialogSuite(),
+ testCurvesROIWidgetSuite(),
+ testAlphaSliderSuite(),
+ testInteractionSuite(),
+ testLegendSelectorSuite(),
+ testMaskToolsWidgetSuite(),
+ testScatterMaskToolsWidgetSuite(),
+ testPlotInteractionSuite(),
+ testPlotSuite(),
+ testPlotToolsSuite(),
+ testPlotWidgetSuite(),
+ testPlotWindowSuite(),
+ testProfileSuite(),
+ testStackViewSuite()])
+ return test_suite
diff --git a/silx/gui/plot/test/testAlphaSlider.py b/silx/gui/plot/test/testAlphaSlider.py
new file mode 100644
index 0000000..304a562
--- /dev/null
+++ b/silx/gui/plot/test/testAlphaSlider.py
@@ -0,0 +1,221 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for ImageAlphaSlider"""
+
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "28/03/2017"
+
+import numpy
+import unittest
+
+from silx.gui import qt
+from silx.gui.test.utils import TestCaseQt
+from silx.gui.plot import PlotWidget
+from silx.gui.plot import AlphaSlider
+
+# Makes sure a QApplication exists
+_qapp = qt.QApplication.instance() or qt.QApplication([])
+
+
+class TestActiveImageAlphaSlider(TestCaseQt):
+ def setUp(self):
+ super(TestActiveImageAlphaSlider, self).setUp()
+ self.plot = PlotWidget()
+ self.aslider = AlphaSlider.ActiveImageAlphaSlider(plot=self.plot)
+ self.aslider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", self.plot)
+ toolbar.addWidget(self.aslider)
+ self.plot.addToolBar(toolbar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.aslider
+
+ super(TestActiveImageAlphaSlider, self).tearDown()
+
+ def testWidgetEnabled(self):
+ # no active image initially, slider must be deactivate
+ self.assertFalse(self.aslider.isEnabled())
+
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]))
+ # now we have an active image
+ self.assertTrue(self.aslider.isEnabled())
+
+ self.plot.setActiveImage(None)
+ self.assertFalse(self.aslider.isEnabled())
+
+ def testGetImage(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]))
+ self.assertEqual(self.plot.getActiveImage(),
+ self.aslider.getItem())
+
+ self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2")
+ self.plot.setActiveImage("2")
+ self.assertEqual(self.plot.getImage("2"),
+ self.aslider.getItem())
+
+ def testGetAlpha(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.aslider.setValue(137)
+ self.assertAlmostEqual(self.aslider.getAlpha(),
+ 137. / 255)
+
+
+class TestNamedImageAlphaSlider(TestCaseQt):
+ def setUp(self):
+ super(TestNamedImageAlphaSlider, self).setUp()
+ self.plot = PlotWidget()
+ self.aslider = AlphaSlider.NamedImageAlphaSlider(plot=self.plot)
+ self.aslider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", self.plot)
+ toolbar.addWidget(self.aslider)
+ self.plot.addToolBar(toolbar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.aslider
+
+ super(TestNamedImageAlphaSlider, self).tearDown()
+
+ def testWidgetEnabled(self):
+ # no image set initially, slider must be deactivate
+ self.assertFalse(self.aslider.isEnabled())
+
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.aslider.setLegend("1")
+ # now we have an image set
+ self.assertTrue(self.aslider.isEnabled())
+
+ def testGetImage(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2")
+ self.aslider.setLegend("1")
+ self.assertEqual(self.plot.getImage("1"),
+ self.aslider.getItem())
+
+ self.aslider.setLegend("2")
+ self.assertEqual(self.plot.getImage("2"),
+ self.aslider.getItem())
+
+ def testGetAlpha(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.aslider.setLegend("1")
+ self.aslider.setValue(128)
+ self.assertAlmostEqual(self.aslider.getAlpha(),
+ 128. / 255)
+
+
+class TestNamedScatterAlphaSlider(TestCaseQt):
+ def setUp(self):
+ super(TestNamedScatterAlphaSlider, self).setUp()
+ self.plot = PlotWidget()
+ self.aslider = AlphaSlider.NamedScatterAlphaSlider(plot=self.plot)
+ self.aslider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", self.plot)
+ toolbar.addWidget(self.aslider)
+ self.plot.addToolBar(toolbar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.aslider
+
+ super(TestNamedScatterAlphaSlider, self).tearDown()
+
+ def testWidgetEnabled(self):
+ # no Scatter set initially, slider must be deactivate
+ self.assertFalse(self.aslider.isEnabled())
+
+ self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7],
+ legend="1")
+ self.aslider.setLegend("1")
+ # now we have an image set
+ self.assertTrue(self.aslider.isEnabled())
+
+ def testGetScatter(self):
+ self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7],
+ legend="1")
+ self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70],
+ legend="2")
+ self.aslider.setLegend("1")
+ self.assertEqual(self.plot.getScatter("1"),
+ self.aslider.getItem())
+
+ self.aslider.setLegend("2")
+ self.assertEqual(self.plot.getScatter("2"),
+ self.aslider.getItem())
+
+ def testGetAlpha(self):
+ self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70],
+ legend="1")
+ self.aslider.setLegend("1")
+ self.aslider.setValue(128)
+ self.assertAlmostEqual(self.aslider.getAlpha(),
+ 128. / 255)
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ # test_suite.addTest(positionInfoTestSuite)
+ for testClass in (TestActiveImageAlphaSlider, TestNamedImageAlphaSlider,
+ TestNamedScatterAlphaSlider):
+ test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
+ testClass))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testColorBar.py b/silx/gui/plot/test/testColorBar.py
new file mode 100644
index 0000000..797ff03
--- /dev/null
+++ b/silx/gui/plot/test/testColorBar.py
@@ -0,0 +1,240 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for ColorBar featues and sub widgets of Colorbar module"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "11/04/2017"
+
+import unittest
+from silx.gui.test.utils import TestCaseQt
+from silx.gui.plot.ColorBar import _ColorScale
+from silx.gui.plot.ColorBar import ColorBarWidget
+from silx.gui.plot import Plot2D
+import numpy
+
+
+class TestColorScale(unittest.TestCase):
+ """Test that interaction with the colorScale is correct"""
+ def setUp(self):
+ self.colorScaleWidget = _ColorScale(colormap=None, parent=None)
+
+ def tearDown(self):
+ self.colorScaleWidget.deleteLater()
+ self.colorScaleWidget = None
+
+ def testRelativePositionLinear(self):
+ self.colorMapLin1 = { 'name': 'gray', 'normalization': 'linear',
+ 'autoscale': False, 'vmin': 0.0, 'vmax': 1.0 }
+ self.colorScaleWidget.setColormap(self.colorMapLin1)
+
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.25) == 0.25)
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.5) == 0.5)
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(1.0) == 1.0)
+
+ self.colorMapLin2 = { 'name': 'viridis', 'normalization': 'linear',
+ 'autoscale': False, 'vmin': -10, 'vmax': 0 }
+ self.colorScaleWidget.setColormap(self.colorMapLin2)
+
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.25) == -7.5)
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.5) == -5.0)
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(1.0) == 0.0)
+
+ def testRelativePositionLog(self):
+ self.colorMapLog1 = { 'name': 'temperature', 'normalization': 'log',
+ 'autoscale': False, 'vmin': 1.0, 'vmax': 100.0 }
+
+ self.colorScaleWidget.setColormap(self.colorMapLog1)
+
+ val = self.colorScaleWidget.getValueFromRelativePosition(1.0)
+ self.assertTrue(val == 100.0)
+
+ val = self.colorScaleWidget.getValueFromRelativePosition(0.5)
+ self.assertTrue(val == 10.0)
+
+ val = self.colorScaleWidget.getValueFromRelativePosition(0.0)
+ self.assertTrue(val == 1.0)
+
+ def testNegativeLogMin(self):
+ colormap = { 'name': 'gray', 'normalization': 'log',
+ 'autoscale': False, 'vmin': -1.0, 'vmax': 1.0 }
+
+ with self.assertRaises(ValueError):
+ self.colorScaleWidget.setColormap(colormap)
+
+ def testNegativeLogMax(self):
+ colormap = { 'name': 'gray', 'normalization': 'log',
+ 'autoscale': False, 'vmin': 1.0, 'vmax': -1.0 }
+
+ with self.assertRaises(ValueError):
+ self.colorScaleWidget.setColormap(colormap)
+
+class TestNoAutoscale(unittest.TestCase):
+ """Test that ticks and color displayed are correct in the case of a colormap
+ with no autoscale
+ """
+
+ def setUp(self):
+ self.plot = Plot2D()
+ self.colorBar = ColorBarWidget(parent=None, plot=self.plot)
+ self.tickBar = self.colorBar.getColorScaleBar().getTickBar()
+ self.colorScale = self.colorBar.getColorScaleBar().getColorScale()
+
+ def tearDown(self):
+ self.tickBar = None
+ self.colorScale = None
+ del self.colorBar
+ self.plot.close()
+ del self.plot
+
+ def testLogNormNoAutoscale(self):
+ colormapLog = { 'name': 'gray', 'normalization': 'log',
+ 'autoscale': False, 'vmin': 1.0, 'vmax': 100.0 }
+
+ data = numpy.linspace(10, 1e10, 9).reshape(3, 3)
+ self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ # test Ticks
+ self.tickBar.setTicksNumber(10)
+ self.tickBar.computeTicks()
+
+ ticksTh = numpy.linspace(1.0, 100.0, 10)
+ ticksTh = 10**ticksTh
+ numpy.array_equal(self.tickBar.ticks, ticksTh)
+
+ # test ColorScale
+ val = self.colorScale.getValueFromRelativePosition(1.0)
+ self.assertTrue(val == 100.0)
+
+ val = self.colorScale.getValueFromRelativePosition(0.0)
+ self.assertTrue(val == 1.0)
+
+ def testLinearNormNoAutoscale(self):
+ colormapLog = { 'name': 'gray', 'normalization': 'linear',
+ 'autoscale': False, 'vmin': -4, 'vmax': 5 }
+
+ data = numpy.linspace(1, 9, 9).reshape(3, 3)
+ self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ # test Ticks
+ self.tickBar.setTicksNumber(10)
+ self.tickBar.computeTicks()
+
+ numpy.array_equal(self.tickBar.ticks, numpy.linspace(-4, 5, 10))
+
+ # test ColorScale
+ val = self.colorScale.getValueFromRelativePosition(1.0)
+ self.assertTrue(val == 5.0)
+
+ val = self.colorScale.getValueFromRelativePosition(0.0)
+ self.assertTrue(val == -4.0)
+
+class TestColorbarWidget(TestCaseQt):
+ """Test interaction with the ColorScaleBar"""
+
+ def setUp(self):
+ super(TestColorbarWidget, self).setUp()
+ self.plot = Plot2D()
+ self.colorBar = ColorBarWidget(parent=None, plot=self.plot)
+
+ def tearDown(self):
+ del self.colorBar
+ self.plot.close()
+ del self.plot
+
+ super(TestColorbarWidget, self).tearDown()
+
+ def testEmptyColorBar(self):
+ colorBar = ColorBarWidget(parent=None)
+ colorBar.show()
+ self.qWaitForWindowExposed(colorBar)
+
+ def testNegativeColormaps(self):
+ """test the behavior of the ColorBarWidget in the case of negative
+ values
+
+ Note : colorbar is modified by the Plot directly not ColorBarWidget
+ """
+ colormapLog = { 'name': 'gray', 'normalization': 'log',
+ 'autoscale': True, 'vmin': -1.0, 'vmax': 1.0 }
+
+ colormapLog2 = { 'name': 'gray', 'normalization': 'log',
+ 'autoscale': False, 'vmin': -1.0, 'vmax': 1.0 }
+
+ data = numpy.array([-5, -4, 0, 2, 3, 5, 10, 20, 30])
+ data = data.reshape(3, 3)
+ self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ # default behavior when autoscale : set to minmal positive value
+ data[data<1] = data.max()
+ self.assertTrue(self.colorBar._colormap['vmin'] == data.min())
+ self.assertTrue(self.colorBar._colormap['vmax'] == data.max())
+
+ data = numpy.linspace(-9, -2, 100).reshape(10, 10)
+
+ self.plot.addImage(data=data, colormap=colormapLog2, legend='toto')
+ self.plot.setActiveImage('toto')
+ # if negative values, changing bounds for default : 1, 10
+ self.assertTrue(self.colorBar._colormap['vmin'] == 1)
+ self.assertTrue(self.colorBar._colormap['vmax'] == 10)
+
+ def testPlotAssocation(self):
+ """Make sure the ColorBarWidget is proparly connected with the plot"""
+ colormap = { 'name': 'gray', 'normalization': 'linear',
+ 'autoscale': True, 'vmin': -1.0, 'vmax': 1.0 }
+
+ # make sure that default settings are the same
+ self.assertTrue(
+ self.colorBar.getColormap() == self.plot.getDefaultColormap())
+
+ data = numpy.linspace(0, 10, 100).reshape(10, 10)
+ self.plot.addImage(data=data, colormap=colormap, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ # make sure the modification of the colormap has been done
+ self.assertFalse(
+ self.colorBar.getColormap() == self.plot.getDefaultColormap())
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ for ui in (TestColorScale, TestNoAutoscale, TestColorbarWidget):
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(ui))
+
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite') \ No newline at end of file
diff --git a/silx/gui/plot/test/testColormapDialog.py b/silx/gui/plot/test/testColormapDialog.py
new file mode 100644
index 0000000..d016548
--- /dev/null
+++ b/silx/gui/plot/test/testColormapDialog.py
@@ -0,0 +1,68 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for ColormapDialog"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import doctest
+import unittest
+
+from silx.gui.test.utils import qWaitForWindowExposedAndActivate
+from silx.gui import qt
+from silx.gui.plot import ColormapDialog
+
+
+# Makes sure a QApplication exists
+_qapp = qt.QApplication.instance() or qt.QApplication([])
+
+
+def _tearDownQt(docTest):
+ """Tear down to use for test from docstring.
+
+ Checks that dialog widget is displayed
+ """
+ dialogWidget = docTest.globs['dialog']
+ qWaitForWindowExposedAndActivate(dialogWidget)
+ dialogWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ dialogWidget.close()
+ del dialogWidget
+ _qapp.processEvents()
+
+
+cmapDocTestSuite = doctest.DocTestSuite(ColormapDialog, tearDown=_tearDownQt)
+"""Test suite of tests from the module's docstrings."""
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(cmapDocTestSuite)
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testColors.py b/silx/gui/plot/test/testColors.py
new file mode 100644
index 0000000..94c22f3
--- /dev/null
+++ b/silx/gui/plot/test/testColors.py
@@ -0,0 +1,94 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for Colors"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import numpy
+
+import unittest
+from silx.test.utils import ParametricTestCase
+
+from silx.gui.plot import Colors
+
+
+class TestRGBA(ParametricTestCase):
+ """Basic tests of rgba function"""
+
+ def testRGBA(self):
+ """"Test rgba function with accepted values"""
+ tests = { # name: (colors, expected values)
+ 'blue': ('blue', (0., 0., 1., 1.)),
+ '#010203': ('#010203', (1. / 255., 2. / 255., 3. / 255., 1.)),
+ '#01020304': ('#01020304', (1. / 255., 2. / 255., 3. / 255., 4. / 255.)),
+ '3 x uint8': (numpy.array((1, 255, 0), dtype=numpy.uint8),
+ (1 / 255., 1., 0., 1.)),
+ '4 x uint8': (numpy.array((1, 255, 0, 1), dtype=numpy.uint8),
+ (1 / 255., 1., 0., 1 / 255.)),
+ '3 x float overflow': ((3., 0.5, 1.), (1., 0.5, 1., 1.)),
+ }
+
+ for name, test in tests.items():
+ color, expected = test
+ with self.subTest(msg=name):
+ result = Colors.rgba(color)
+ self.assertEqual(result, expected)
+
+
+class TestApplyColormapToData(ParametricTestCase):
+ """Tests of applyColormapToData function"""
+
+ def testApplyColormapToData(self):
+ """Simple test of applyColormapToData function"""
+ colormap = dict(name='gray', normalization='linear',
+ autoscale=False, vmin=0, vmax=255)
+
+ size = 10
+ expected = numpy.empty((size, 4), dtype='uint8')
+ expected[:, 0] = numpy.arange(size, dtype='uint8')
+ expected[:, 1] = expected[:, 0]
+ expected[:, 2] = expected[:, 0]
+ expected[:, 3] = 255
+
+ for dtype in ('uint8', 'int32', 'float32', 'float64'):
+ with self.subTest(dtype=dtype):
+ array = numpy.arange(size, dtype=dtype)
+ result = Colors.applyColormapToData(array, **colormap)
+ self.assertTrue(numpy.all(numpy.equal(result, expected)))
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ for testClass in (TestRGBA, TestApplyColormapToData):
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(testClass))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testCurvesROIWidget.py b/silx/gui/plot/test/testCurvesROIWidget.py
new file mode 100644
index 0000000..3c6f2ba
--- /dev/null
+++ b/silx/gui/plot/test/testCurvesROIWidget.py
@@ -0,0 +1,153 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for CurvesROIWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import logging
+import os.path
+import unittest
+
+import numpy
+
+from silx.gui import qt
+from silx.test.utils import temp_dir
+from silx.gui.test.utils import TestCaseQt
+from silx.gui.plot import PlotWindow, CurvesROIWidget
+
+
+logging.basicConfig()
+_logger = logging.getLogger(__name__)
+
+
+class TestCurvesROIWidget(TestCaseQt):
+ """Basic test for CurvesROIWidget"""
+
+ def setUp(self):
+ super(TestCurvesROIWidget, self).setUp()
+ self.plot = PlotWindow()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.widget = CurvesROIWidget.CurvesROIDockWidget(plot=self.plot, name='TEST')
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+
+ super(TestCurvesROIWidget, self).tearDown()
+
+ def testEmptyPlot(self):
+ """Empty plot, display ROI widget"""
+ pass
+
+ def testWithCurves(self):
+ """Plot with curves: test all ROI widget buttons"""
+ for offset in range(2):
+ self.plot.addCurve(numpy.arange(1000),
+ offset + numpy.random.random(1000),
+ legend=str(offset))
+
+ # Add two ROI
+ self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+ self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+
+ # Change active curve
+ self.plot.setActiveCurve(str(1))
+
+ # Delete a ROI
+ self.mouseClick(self.widget.roiWidget.delButton, qt.Qt.LeftButton)
+
+ with temp_dir() as tmpDir:
+ self.tmpFile = os.path.join(tmpDir, 'test.ini')
+
+ # Save ROIs
+ self.widget.roiWidget.save(self.tmpFile)
+ self.assertTrue(os.path.isfile(self.tmpFile))
+
+ # Reset ROIs
+ self.mouseClick(self.widget.roiWidget.resetButton,
+ qt.Qt.LeftButton)
+
+ # Load ROIs
+ self.widget.roiWidget.load(self.tmpFile)
+
+ del self.tmpFile
+
+ def testCalculation(self):
+ x = numpy.arange(100.)
+ y = numpy.arange(100.)
+
+ # Add two curves
+ self.plot.addCurve(x, y, legend="positive")
+ self.plot.addCurve(-x, y, legend="negative")
+
+ # Make sure there is an active curve and it is the positive one
+ self.plot.setActiveCurve("positive")
+
+ # Add two ROIs
+ ddict = {}
+ ddict["positive"] = {"from": 10, "to": 20, "type":"X"}
+ ddict["negative"] = {"from": -20, "to": -10, "type":"X"}
+ self.widget.roiWidget.setRois(ddict)
+
+ # And calculate the expected output
+ self.widget.calculateROIs()
+
+ output = self.widget.roiWidget.getRois()
+ self.assertEqual(output["positive"]["rawcounts"],
+ y[ddict["positive"]["from"]:ddict["positive"]["to"]+1].sum(),
+ "Calculation failed on positive X coordinates")
+
+ # Set the curve with negative X coordinates as active
+ self.plot.setActiveCurve("negative")
+
+ # the ROIs should have been automatically updated
+ output = self.widget.roiWidget.getRois()
+ selection = numpy.nonzero((-x >= output["negative"]["from"]) & \
+ (-x <= output["negative"]["to"]))[0]
+ self.assertEqual(output["negative"]["rawcounts"],
+ y[selection].sum(), "Calculation failed on negative X coordinates")
+
+def suite():
+ test_suite = unittest.TestSuite()
+ for TestClass in (TestCurvesROIWidget,):
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testInteraction.py b/silx/gui/plot/test/testInteraction.py
new file mode 100644
index 0000000..074a7cd
--- /dev/null
+++ b/silx/gui/plot/test/testInteraction.py
@@ -0,0 +1,89 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests from interaction state machines"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/02/2016"
+
+
+import unittest
+
+from silx.gui.plot import Interaction
+
+
+class TestInteraction(unittest.TestCase):
+ def testClickOrDrag(self):
+ """Minimalistic test for click or drag state machine."""
+ events = []
+
+ class TestClickOrDrag(Interaction.ClickOrDrag):
+ def click(self, x, y, btn):
+ events.append(('click', x, y, btn))
+
+ def beginDrag(self, x, y):
+ events.append(('beginDrag', x, y))
+
+ def drag(self, x, y):
+ events.append(('drag', x, y))
+
+ def endDrag(self, x, y):
+ events.append(('endDrag', x, y))
+
+ clickOrDrag = TestClickOrDrag()
+
+ # click
+ clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 0)
+
+ clickOrDrag.handleEvent('release', 10, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 1)
+ self.assertEqual(events[0], ('click', 10, 10, Interaction.LEFT_BTN))
+
+ # drag
+ events = []
+ clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 0)
+ clickOrDrag.handleEvent('move', 15, 10)
+ self.assertEqual(len(events), 2) # Received beginDrag and drag
+ self.assertEqual(events[0], ('beginDrag', 10, 10))
+ self.assertEqual(events[1], ('drag', 15, 10))
+ clickOrDrag.handleEvent('move', 20, 10)
+ self.assertEqual(len(events), 3)
+ self.assertEqual(events[-1], ('drag', 20, 10))
+ clickOrDrag.handleEvent('release', 20, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 4)
+ self.assertEqual(events[-1], ('endDrag', (10, 10), (20, 10)))
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestInteraction))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testLegendSelector.py b/silx/gui/plot/test/testLegendSelector.py
new file mode 100644
index 0000000..371197f
--- /dev/null
+++ b/silx/gui/plot/test/testLegendSelector.py
@@ -0,0 +1,143 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["T. Rueter", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import logging
+import unittest
+
+from silx.gui import qt
+from silx.gui.test.utils import TestCaseQt
+from silx.gui.plot import LegendSelector
+
+
+logging.basicConfig()
+_logger = logging.getLogger(__name__)
+
+
+class TestLegendSelector(TestCaseQt):
+ """Basic test for LegendSelector"""
+
+ def testLegendSelector(self):
+ """Test copied from __main__ of LegendSelector in PyMca"""
+ class Notifier(qt.QObject):
+ def __init__(self):
+ qt.QObject.__init__(self)
+ self.chk = True
+
+ def signalReceived(self, **kw):
+ obj = self.sender()
+ _logger.info('NOTIFIER -- signal received\n\tsender: %s',
+ str(obj))
+
+ notifier = Notifier()
+
+ legends = ['Legend0',
+ 'Legend1',
+ 'Long Legend 2',
+ 'Foo Legend 3',
+ 'Even Longer Legend 4',
+ 'Short Leg 5',
+ 'Dot symbol 6',
+ 'Comma symbol 7']
+ colors = [qt.Qt.darkRed, qt.Qt.green, qt.Qt.yellow, qt.Qt.darkCyan,
+ qt.Qt.blue, qt.Qt.darkBlue, qt.Qt.red, qt.Qt.darkYellow]
+ symbols = ['o', 't', '+', 'x', 's', 'd', '.', ',']
+
+ win = LegendSelector.LegendListView()
+ # win = LegendListContextMenu()
+ # win = qt.QWidget()
+ # layout = qt.QVBoxLayout()
+ # layout.setContentsMargins(0,0,0,0)
+ llist = []
+
+ for _idx, (l, c, s) in enumerate(zip(legends, colors, symbols)):
+ ddict = {
+ 'color': qt.QColor(c),
+ 'linewidth': 4,
+ 'symbol': s,
+ }
+ legend = l
+ llist.append((legend, ddict))
+ # item = qt.QListWidgetItem(win)
+ # legendWidget = LegendListItemWidget(l)
+ # legendWidget.icon.setSymbol(s)
+ # legendWidget.icon.setColor(qt.QColor(c))
+ # layout.addWidget(legendWidget)
+ # win.setItemWidget(item, legendWidget)
+
+ # win = LegendListItemWidget('Some Legend 1')
+ # print(llist)
+ model = LegendSelector.LegendModel(legendList=llist)
+ win.setModel(model)
+ win.setSelectionModel(qt.QItemSelectionModel(model))
+ win.setContextMenu()
+ # print('Edit triggers: %d'%win.editTriggers())
+
+ # win = LegendListWidget(None, legends)
+ # win[0].updateItem(ddict)
+ # win.setLayout(layout)
+ win.sigLegendSignal.connect(notifier.signalReceived)
+ win.show()
+
+ win.clear()
+ win.setLegendList(llist)
+
+ self.qWaitForWindowExposed(win)
+
+
+class TestRenameCurveDialog(TestCaseQt):
+ """Basic test for RenameCurveDialog"""
+
+ def testDialog(self):
+ """Create dialog, change name and press OK"""
+ self.dialog = LegendSelector.RenameCurveDialog(
+ None, 'curve1', ['curve1', 'curve2', 'curve3'])
+ self.dialog.open()
+ self.qWaitForWindowExposed(self.dialog)
+ self.keyClicks(self.dialog.lineEdit, 'changed')
+ self.mouseClick(self.dialog.okButton, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ ret = self.dialog.result()
+ self.assertEqual(ret, qt.QDialog.Accepted)
+ newName = self.dialog.getText()
+ self.assertEqual(newName, 'curve1changed')
+ del self.dialog
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ for TestClass in (TestLegendSelector, TestRenameCurveDialog):
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testMaskToolsWidget.py b/silx/gui/plot/test/testMaskToolsWidget.py
new file mode 100644
index 0000000..0c11928
--- /dev/null
+++ b/silx/gui/plot/test/testMaskToolsWidget.py
@@ -0,0 +1,295 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for MaskToolsWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/01/2017"
+
+
+import logging
+import os.path
+import unittest
+
+import numpy
+
+from silx.gui import qt
+from silx.test.utils import temp_dir, ParametricTestCase
+from silx.gui.test.utils import TestCaseQt, getQToolButtonFromAction
+from silx.gui.plot import PlotWindow, MaskToolsWidget
+
+try:
+ import fabio
+except ImportError:
+ fabio = None
+
+
+logging.basicConfig()
+_logger = logging.getLogger(__name__)
+
+
+class TestMaskToolsWidget(TestCaseQt, ParametricTestCase):
+ """Basic test for MaskToolsWidget"""
+
+ def setUp(self):
+ super(TestMaskToolsWidget, self).setUp()
+ self.plot = PlotWindow()
+
+ self.widget = MaskToolsWidget.MaskToolsDockWidget(plot=self.plot, name='TEST')
+ self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.maskWidget = self.widget.widget()
+
+ def tearDown(self):
+ del self.maskWidget
+ del self.widget
+
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+
+ super(TestMaskToolsWidget, self).tearDown()
+
+ def testEmptyPlot(self):
+ """Empty plot, display MaskToolsDockWidget, toggle multiple masks"""
+ self.maskWidget.setMultipleMasks('single')
+ self.qapp.processEvents()
+
+ self.maskWidget.setMultipleMasks('exclusive')
+ self.qapp.processEvents()
+
+ def _drag(self):
+ """Drag from plot center to offset position"""
+ plot = self.plot.centralWidget()
+ xCenter, yCenter = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ pos0 = xCenter, yCenter
+ pos1 = xCenter + offset, yCenter + offset
+
+ self.mouseMove(plot, pos=pos0)
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos0)
+ self.mouseMove(plot, pos=pos1)
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1)
+
+ def _drawPolygon(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.centralWidget()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [(x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset)]
+
+ for pos in star:
+ self.mouseMove(plot, pos=pos)
+ btn = qt.Qt.LeftButton if pos != star[-1] else qt.Qt.RightButton
+ self.mouseClick(plot, btn, pos=pos)
+
+ def _drawPencil(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.centralWidget()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [(x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset)]
+
+ self.mouseMove(plot, pos=star[0])
+ self.mousePress(plot, qt.Qt.LeftButton, pos=star[0])
+ for pos in star:
+ self.mouseMove(plot, pos=pos)
+ self.mouseRelease(
+ plot, qt.Qt.LeftButton, pos=star[-1])
+
+ def testWithAnImage(self):
+ """Plot with an image: test MaskToolsWidget interactions"""
+
+ # Add and remove a image (this should enable/disable GUI + change mask)
+ self.plot.addImage(numpy.random.random(1024**2).reshape(1024, 1024),
+ legend='test')
+ self.qapp.processEvents()
+
+ self.plot.remove('test', kind='image')
+ self.qapp.processEvents()
+
+ tests = [((0, 0), (1, 1)),
+ ((1000, 1000), (1, 1)),
+ ((0, 0), (-1, -1)),
+ ((1000, 1000), (-1, -1))]
+
+ for origin, scale in tests:
+ with self.subTest(origin=origin, scale=scale):
+ self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024),
+ legend='test',
+ origin=origin,
+ scale=scale)
+ self.qapp.processEvents()
+
+ # Test draw rectangle #
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test draw polygon #
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test draw pencil #
+ toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.maskWidget.pencilSpinBox.setValue(10)
+ self.qapp.processEvents()
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test no draw tool #
+ toolButton = getQToolButtonFromAction(self.maskWidget.browseAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.plot.clear()
+
+ def __loadSave(self, file_format):
+ """Plot with an image: test MaskToolsWidget operations"""
+ self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024),
+ legend='test')
+ self.qapp.processEvents()
+
+ # Draw a polygon mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self._drawPolygon()
+
+ ref_mask = self.maskWidget.getSelectionMask()
+ self.assertFalse(numpy.all(numpy.equal(ref_mask, 0)))
+
+ with temp_dir() as tmp:
+ mask_filename = os.path.join(tmp, 'mask.' + file_format)
+ self.maskWidget.save(mask_filename, file_format)
+
+ self.maskWidget.resetSelectionMask()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ self.maskWidget.load(mask_filename)
+ self.assertTrue(numpy.all(numpy.equal(
+ self.maskWidget.getSelectionMask(), ref_mask)))
+
+ def testLoadSaveNpy(self):
+ self.__loadSave("npy")
+
+ def testLoadSaveFit2D(self):
+ if fabio is None:
+ self.skipTest("Fabio is missing")
+ self.__loadSave("msk")
+
+ def testSigMaskChangedEmitted(self):
+ self.plot.addImage(numpy.arange(512**2).reshape(512, 512),
+ legend='test')
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ l = []
+
+ def slot():
+ l.append(1)
+
+ self.maskWidget.sigMaskChanged.connect(slot)
+
+ # rectangle mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+
+ self.assertGreater(len(l), 0)
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ for TestClass in (TestMaskToolsWidget,):
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testPlot.py b/silx/gui/plot/test/testPlot.py
new file mode 100644
index 0000000..25e7511
--- /dev/null
+++ b/silx/gui/plot/test/testPlot.py
@@ -0,0 +1,633 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for Plot"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import unittest
+from functools import reduce
+from silx.test.utils import ParametricTestCase
+
+import numpy
+
+from silx.gui.plot.Plot import Plot
+from silx.gui.plot.items.histogram import _getHistogramCurve, _computeEdges
+
+
+class TestPlot(unittest.TestCase):
+ """Basic tests of Plot without backend"""
+
+ def testPlotTitleLabels(self):
+ """Create a Plot and set the labels"""
+
+ plot = Plot(backend='none')
+
+ title, xlabel, ylabel = 'the title', 'x label', 'y label'
+ plot.setGraphTitle(title)
+ plot.setGraphXLabel(xlabel)
+ plot.setGraphYLabel(ylabel)
+
+ self.assertEqual(plot.getGraphTitle(), title)
+ self.assertEqual(plot.getGraphXLabel(), xlabel)
+ self.assertEqual(plot.getGraphYLabel(), ylabel)
+
+ def testAddNoRemove(self):
+ """add objects to the Plot"""
+
+ plot = Plot(backend='none')
+ plot.addCurve(x=(1, 2, 3), y=(3, 2, 1))
+ plot.addImage(numpy.arange(100.).reshape(10, -1))
+ plot.addItem(
+ numpy.array((1., 10.)), numpy.array((10., 10.)), shape="rectangle")
+ plot.addXMarker(10.)
+
+
+class TestPlotRanges(ParametricTestCase):
+ """Basic tests of Plot data ranges without backend"""
+
+ _getValidValues = {True: lambda ar: ar > 0,
+ False: lambda ar: numpy.ones(shape=ar.shape,
+ dtype=bool)}
+
+ @staticmethod
+ def _getRanges(arrays, are_logs):
+ gen = (TestPlotRanges._getValidValues[is_log](ar)
+ for (ar, is_log) in zip(arrays, are_logs))
+ indices = numpy.where(reduce(numpy.logical_and, gen))[0]
+ if len(indices) > 0:
+ ranges = [(ar[indices[0]], ar[indices[-1]]) for ar in arrays]
+ else:
+ ranges = [None] * len(arrays)
+
+ return ranges
+
+ @staticmethod
+ def _getRangesMinmax(ranges):
+ # TODO : error if None in ranges.
+ rangeMin = numpy.min([rng[0] for rng in ranges])
+ rangeMax = numpy.max([rng[1] for rng in ranges])
+ return rangeMin, rangeMax
+
+ def testDataRangeNoPlot(self):
+ """empty plot data range"""
+
+ plot = Plot(backend='none')
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.setXAxisLogarithmic(logX)
+ plot.setYAxisLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ self.assertIsNone(dataRange.x)
+ self.assertIsNone(dataRange.y)
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeLeft(self):
+ """left axis range"""
+
+ plot = Plot(backend='none')
+
+ xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
+ yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
+
+ plot.addCurve(x=xData,
+ y=yData,
+ legend='plot_0',
+ yaxis='left')
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.setXAxisLogarithmic(logX)
+ plot.setYAxisLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = self._getRanges([xData, yData],
+ [logX, logY])
+ self.assertSequenceEqual(dataRange.x, xRange)
+ self.assertSequenceEqual(dataRange.y, yRange)
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeRight(self):
+ """right axis range"""
+
+ plot = Plot(backend='none')
+ xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
+ yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
+ plot.addCurve(x=xData,
+ y=yData,
+ legend='plot_0',
+ yaxis='right')
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.setXAxisLogarithmic(logX)
+ plot.setYAxisLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = self._getRanges([xData, yData],
+ [logX, logY])
+ self.assertSequenceEqual(dataRange.x, xRange)
+ self.assertIsNone(dataRange.y)
+ self.assertSequenceEqual(dataRange.yright, yRange)
+
+ def testDataRangeImage(self):
+ """image data range"""
+
+ origin = (-10, 25)
+ scale = (3., 8.)
+ image = numpy.arange(100.).reshape(20, 5)
+
+ plot = Plot(backend='none')
+ plot.addImage(image,
+ origin=origin, scale=scale)
+
+ xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
+ yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
+
+ ranges = {(False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None)}
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.setXAxisLogarithmic(logX)
+ plot.setYAxisLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = ranges[logX, logY]
+ self.assertTrue(numpy.array_equal(dataRange.x, xRange),
+ msg='{0} != {1}'.format(dataRange.x, xRange))
+ self.assertTrue(numpy.array_equal(dataRange.y, yRange),
+ msg='{0} != {1}'.format(dataRange.y, yRange))
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeLeftRight(self):
+ """right+left axis range"""
+
+ plot = Plot(backend='none')
+
+ xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
+ yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1
+ plot.addCurve(x=xData_l,
+ y=yData_l,
+ legend='plot_l',
+ yaxis='left')
+
+ xData_r = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
+ yData_r = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
+ plot.addCurve(x=xData_r,
+ y=yData_r,
+ legend='plot_r',
+ yaxis='right')
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.setXAxisLogarithmic(logX)
+ plot.setYAxisLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRangeL, yRangeL = self._getRanges([xData_l, yData_l],
+ [logX, logY])
+ xRangeR, yRangeR = self._getRanges([xData_r, yData_r],
+ [logX, logY])
+ xRangeLR = self._getRangesMinmax([xRangeL, xRangeR])
+ self.assertSequenceEqual(dataRange.x, xRangeLR)
+ self.assertSequenceEqual(dataRange.y, yRangeL)
+ self.assertSequenceEqual(dataRange.yright, yRangeR)
+
+ def testDataRangeCurveImage(self):
+ """right+left+image axis range"""
+
+ # overlapping ranges :
+ # image sets x min and y max
+ # plot_left sets y min
+ # plot_right sets x max (and yright)
+ plot = Plot(backend='none')
+
+ origin = (-10, 5)
+ scale = (3., 8.)
+ image = numpy.arange(100.).reshape(20, 5)
+
+ plot.addImage(image,
+ origin=origin, scale=scale, legend='image')
+
+ xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
+ yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1
+ plot.addCurve(x=xData_l,
+ y=yData_l,
+ legend='plot_l',
+ yaxis='left')
+
+ xData_r = numpy.arange(10) + 4.1 # range : 4.1 , 13.1
+ yData_r = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
+ plot.addCurve(x=xData_r,
+ y=yData_r,
+ legend='plot_r',
+ yaxis='right')
+
+ imgXRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
+ imgYRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.setXAxisLogarithmic(logX)
+ plot.setYAxisLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRangeL, yRangeL = self._getRanges([xData_l, yData_l],
+ [logX, logY])
+ xRangeR, yRangeR = self._getRanges([xData_r, yData_r],
+ [logX, logY])
+ if logX or logY:
+ xRangeLR = self._getRangesMinmax([xRangeL, xRangeR])
+ else:
+ xRangeLR = self._getRangesMinmax([xRangeL,
+ xRangeR,
+ imgXRange])
+ yRangeL = self._getRangesMinmax([yRangeL, imgYRange])
+ self.assertSequenceEqual(dataRange.x, xRangeLR)
+ self.assertSequenceEqual(dataRange.y, yRangeL)
+ self.assertSequenceEqual(dataRange.yright, yRangeR)
+
+ def testDataRangeImageNegativeScaleX(self):
+ """image data range, negative scale"""
+
+ origin = (-10, 25)
+ scale = (-3., 8.)
+ image = numpy.arange(100.).reshape(20, 5)
+
+ plot = Plot(backend='none')
+ plot.addImage(image,
+ origin=origin, scale=scale)
+
+ xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
+ xRange.sort() # negative scale!
+ yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
+
+ ranges = {(False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None)}
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.setXAxisLogarithmic(logX)
+ plot.setYAxisLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = ranges[logX, logY]
+ self.assertTrue(numpy.array_equal(dataRange.x, xRange),
+ msg='{0} != {1}'.format(dataRange.x, xRange))
+ self.assertTrue(numpy.array_equal(dataRange.y, yRange),
+ msg='{0} != {1}'.format(dataRange.y, yRange))
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeImageNegativeScaleY(self):
+ """image data range, negative scale"""
+
+ origin = (-10, 25)
+ scale = (3., -8.)
+ image = numpy.arange(100.).reshape(20, 5)
+
+ plot = Plot(backend='none')
+ plot.addImage(image,
+ origin=origin, scale=scale)
+
+ xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
+ yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
+ yRange.sort() # negative scale!
+
+ ranges = {(False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None)}
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.setXAxisLogarithmic(logX)
+ plot.setYAxisLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = ranges[logX, logY]
+ self.assertTrue(numpy.array_equal(dataRange.x, xRange),
+ msg='{0} != {1}'.format(dataRange.x, xRange))
+ self.assertTrue(numpy.array_equal(dataRange.y, yRange),
+ msg='{0} != {1}'.format(dataRange.y, yRange))
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeHiddenCurve(self):
+ """curves with a hidden curve"""
+ plot = Plot(backend='none')
+ plot.addCurve((0, 1), (0, 1), legend='shown')
+ plot.addCurve((0, 1, 2), (5, 5, 5), legend='hidden')
+ range1 = plot.getDataRange()
+ self.assertEqual(range1.x, (0, 2))
+ self.assertEqual(range1.y, (0, 5))
+ plot.hideCurve('hidden')
+ range2 = plot.getDataRange()
+ self.assertEqual(range2.x, (0, 1))
+ self.assertEqual(range2.y, (0, 1))
+
+
+class TestPlotGetCurveImage(unittest.TestCase):
+ """Test of plot getCurve and getImage methods"""
+
+ def testGetCurve(self):
+ """Plot.getCurve and Plot.getActiveCurve tests"""
+
+ plot = Plot(backend='none')
+
+ # No curve
+ curve = plot.getCurve()
+ self.assertIsNone(curve) # No curve
+
+ plot.setActiveCurveHandling(True)
+ plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 0')
+ plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 1')
+ plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 2')
+ plot.setActiveCurve('curve 0')
+
+ # Active curve
+ active = plot.getActiveCurve()
+ self.assertEqual(active.getLegend(), 'curve 0')
+ curve = plot.getCurve()
+ self.assertEqual(curve.getLegend(), 'curve 0')
+
+ # No active curve and curves
+ plot.setActiveCurveHandling(False)
+ active = plot.getActiveCurve()
+ self.assertIsNone(active) # No active curve
+ curve = plot.getCurve()
+ self.assertEqual(curve.getLegend(), 'curve 2') # Last added curve
+
+ # Last curve hidden
+ plot.hideCurve('curve 2', True)
+ curve = plot.getCurve()
+ self.assertEqual(curve.getLegend(), 'curve 1') # Last added curve
+
+ # All curves hidden
+ plot.hideCurve('curve 1', True)
+ plot.hideCurve('curve 0', True)
+ curve = plot.getCurve()
+ self.assertIsNone(curve)
+
+ def testGetCurveOldApi(self):
+ """old API Plot.getCurve and Plot.getActiveCurve tests"""
+
+ plot = Plot(backend='none')
+
+ # No curve
+ curve = plot.getCurve()
+ self.assertIsNone(curve) # No curve
+
+ plot.setActiveCurveHandling(True)
+ x = numpy.arange(10.).astype(numpy.float32)
+ y = x * x;
+ plot.addCurve(x=x, y=y, legend='curve 0', info=["whatever"])
+ plot.addCurve(x=x, y=2*x, legend='curve 1', info="anything")
+ plot.setActiveCurve('curve 0')
+
+ # Active curve (4 elements)
+ xOut, yOut, legend, info = plot.getActiveCurve()[:4]
+ self.assertEqual(legend, 'curve 0')
+ self.assertTrue(numpy.allclose(xOut, x), 'curve 0 wrong x data')
+ self.assertTrue(numpy.allclose(yOut, y), 'curve 0 wrong y data')
+
+ # Active curve (5 elements)
+ xOut, yOut, legend, info, params = plot.getCurve("curve 1")
+ self.assertEqual(legend, 'curve 1')
+ self.assertEqual(info, 'anything')
+ self.assertTrue(numpy.allclose(xOut, x), 'curve 1 wrong x data')
+ self.assertTrue(numpy.allclose(yOut, 2*x), 'curve 1 wrong y data')
+
+ def testGetImage(self):
+ """Plot.getImage and Plot.getActiveImage tests"""
+
+ plot = Plot(backend='none')
+
+ # No image
+ image = plot.getImage()
+ self.assertIsNone(image)
+
+ plot.addImage(((0, 1), (2, 3)), legend='image 0', replace=False)
+ plot.addImage(((0, 1), (2, 3)), legend='image 1', replace=False)
+
+ # Active image
+ active = plot.getActiveImage()
+ self.assertEqual(active.getLegend(), 'image 0')
+ image = plot.getImage()
+ self.assertEqual(image.getLegend(), 'image 0')
+
+ # No active image
+ plot.addImage(((0, 1), (2, 3)), legend='image 2', replace=False)
+ plot.setActiveImage(None)
+ active = plot.getActiveImage()
+ self.assertIsNone(active)
+ image = plot.getImage()
+ self.assertEqual(image.getLegend(), 'image 2')
+
+ # Active image
+ plot.setActiveImage('image 1')
+ active = plot.getActiveImage()
+ self.assertEqual(active.getLegend(), 'image 1')
+ image = plot.getImage()
+ self.assertEqual(image.getLegend(), 'image 1')
+
+ def testGetImageOldApi(self):
+ """Plot.getImage and Plot.getActiveImage old API tests"""
+
+ plot = Plot(backend='none')
+
+ # No image
+ image = plot.getImage()
+ self.assertIsNone(image)
+
+ image = numpy.arange(10).astype(numpy.float32)
+ image.shape = 5, 2
+
+ plot.addImage(image, legend='image 0', info=["Hi!"], replace=False)
+
+ # Active image
+ data, legend, info, something, params = plot.getActiveImage()
+ self.assertEqual(legend, 'image 0')
+ self.assertEqual(info, ["Hi!"])
+ self.assertTrue(numpy.allclose(data, image), "image 0 data not correct")
+
+ def testGetAllImages(self):
+ """Plot.getAllImages test"""
+
+ plot = Plot(backend='none')
+
+ # No image
+ images = plot.getAllImages()
+ self.assertEqual(len(images), 0)
+
+ # 2 images
+ data = numpy.arange(100).reshape(10, 10)
+ plot.addImage(data, legend='1', replace=False)
+ plot.addImage(data, origin=(10, 10), legend='2', replace=False)
+ images = plot.getAllImages(just_legend=True)
+ self.assertEqual(list(images), ['1', '2'])
+ images = plot.getAllImages(just_legend=False)
+ self.assertEqual(len(images), 2)
+ self.assertEqual(images[0].getLegend(), '1')
+ self.assertEqual(images[1].getLegend(), '2')
+
+
+class TestPlotAddScatter(unittest.TestCase):
+ """Test of plot addScatter"""
+
+ def testAddGetScatter(self):
+
+ plot = Plot(backend='none')
+
+ # No curve
+ scatter = plot._getItem(kind="scatter")
+ self.assertIsNone(scatter) # No curve
+
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2')
+ plot._setActiveItem('scatter', 'scatter 0')
+
+ # Active scatter
+ active = plot._getActiveItem(kind='scatter')
+ self.assertEqual(active.getLegend(), 'scatter 0')
+
+ # check default values
+ self.assertAlmostEqual(active.getSymbolSize(), active._DEFAULT_SYMBOL_SIZE)
+ self.assertEqual(active.getSymbol(), "o")
+ self.assertAlmostEqual(active.getAlpha(), 1.0)
+
+ # modify parameters
+ active.setSymbolSize(20.5)
+ active.setSymbol("d")
+ active.setAlpha(0.777)
+
+ s0 = plot.getScatter("scatter 0")
+
+ self.assertAlmostEqual(s0.getSymbolSize(), 20.5)
+ self.assertEqual(s0.getSymbol(), "d")
+ self.assertAlmostEqual(s0.getAlpha(), 0.777)
+
+ scatter1 = plot._getItem(kind='scatter', legend='scatter 1')
+ self.assertEqual(scatter1.getLegend(), 'scatter 1')
+
+ def testGetAllScatters(self):
+ """Plot.getAllImages test"""
+
+ plot = Plot(backend='none')
+
+ scatters = plot._getItems(kind='scatter')
+ self.assertEqual(len(scatters), 0)
+
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2')
+
+ scatters = plot._getItems(kind='scatter')
+ self.assertEqual(len(scatters), 3)
+ self.assertEqual(scatters[0].getLegend(), 'scatter 0')
+ self.assertEqual(scatters[2].getLegend(), 'scatter 2')
+
+ scatters = plot._getItems(kind='scatter', just_legend=True)
+ self.assertEqual(len(scatters), 3)
+ self.assertEqual(list(scatters), ['scatter 0', 'scatter 1', 'scatter 2'])
+
+
+class TestPlotHistogram(unittest.TestCase):
+ """Basic tests for histogram."""
+
+ def testEdges(self):
+ x = numpy.array([0, 1, 2])
+ edgesRight = numpy.array([0, 1, 2, 3])
+ edgesLeft = numpy.array([-1, 0, 1, 2])
+ edgesCenter = numpy.array([-0.5, 0.5, 1.5, 2.5])
+
+ # testing x values for right
+ edges = _computeEdges(x, 'right')
+ numpy.testing.assert_array_equal(edges, edgesRight)
+
+ edges = _computeEdges(x, 'center')
+ numpy.testing.assert_array_equal(edges, edgesCenter)
+
+ edges = _computeEdges(x, 'left')
+ numpy.testing.assert_array_equal(edges, edgesLeft)
+
+ def testHistogramCurve(self):
+ y = numpy.array([3, 2, 5])
+ edges = numpy.array([0, 1, 2, 3])
+
+ xHisto, yHisto = _getHistogramCurve(y, edges)
+ numpy.testing.assert_array_equal(
+ yHisto, numpy.array([3, 3, 2, 2, 5, 5]))
+
+ y = numpy.array([-3, 2, 5, 0])
+ edges = numpy.array([-2, -1, 0, 1, 2])
+ xHisto, yHisto = _getHistogramCurve(y, edges)
+ numpy.testing.assert_array_equal(
+ yHisto, numpy.array([-3, -3, 2, 2, 5, 5, 0, 0]))
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ for TestClass in (TestPlot, TestPlotRanges, TestPlotGetCurveImage,
+ TestPlotHistogram, TestPlotAddScatter):
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testPlotInteraction.py b/silx/gui/plot/test/testPlotInteraction.py
new file mode 100644
index 0000000..25f57a9
--- /dev/null
+++ b/silx/gui/plot/test/testPlotInteraction.py
@@ -0,0 +1,167 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests of plot interaction, through a PlotWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "13/10/2016"
+
+
+import unittest
+from silx.gui import qt
+from silx.gui.plot.test.testPlotWidget import _PlotWidgetTest
+
+
+class _SignalDump(object):
+ """Callable object that store passed arguments in a list"""
+
+ def __init__(self):
+ self._received = []
+
+ def __call__(self, *args):
+ self._received.append(args)
+
+ @property
+ def received(self):
+ """Return a shallow copy of the list of received arguments"""
+ return list(self._received)
+
+
+class TestSelectPolygon(_PlotWidgetTest):
+ """Test polygon selection interaction"""
+
+ def _interactionModeChanged(self, source):
+ """Check that source received in event is the correct one"""
+ self.assertEqual(source, self)
+
+ def _draw(self, polygon):
+ """Draw a polygon in the plot
+
+ :param polygon: List of points (x, y) of the polygon (not closed)
+ """
+ plot = self.plot.centralWidget()
+
+ dump = _SignalDump()
+ self.plot.sigPlotSignal.connect(dump)
+
+ for pos in polygon:
+ self.mouseMove(plot, pos=pos)
+ btn = qt.Qt.LeftButton if pos != polygon[-1] else qt.Qt.RightButton
+ self.mouseClick(plot, btn, pos=pos)
+
+ self.plot.sigPlotSignal.disconnect(dump)
+ return [args[0] for args in dump.received]
+
+ def test(self):
+ """Test draw polygons + events"""
+ self.plot.sigInteractiveModeChanged.connect(
+ self._interactionModeChanged)
+
+ self.plot.setInteractiveMode(
+ 'draw', shape='polygon', label='test', source=self)
+ interaction = self.plot.getInteractiveMode()
+
+ self.assertEqual(interaction['mode'], 'draw')
+ self.assertEqual(interaction['shape'], 'polygon')
+
+ self.plot.sigInteractiveModeChanged.disconnect(
+ self._interactionModeChanged)
+
+ plot = self.plot.centralWidget()
+ xCenter, yCenter = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ # Star polygon
+ star = [(xCenter, yCenter + offset),
+ (xCenter - offset, yCenter - offset),
+ (xCenter + offset, yCenter),
+ (xCenter - offset, yCenter),
+ (xCenter + offset, yCenter - offset)]
+
+ # Draw while dumping signals
+ events = self._draw(star)
+
+ # Test last event
+ drawEvents = [event for event in events
+ if event['event'].startswith('drawing')]
+ self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
+ self.assertEqual(len(drawEvents[-1]['points']), 6)
+
+ # Large square
+ largeSquare = [(xCenter - offset, yCenter - offset),
+ (xCenter + offset, yCenter - offset),
+ (xCenter + offset, yCenter + offset),
+ (xCenter - offset, yCenter + offset)]
+
+ # Draw while dumping signals
+ events = self._draw(largeSquare)
+
+ # Test last event
+ drawEvents = [event for event in events
+ if event['event'].startswith('drawing')]
+ self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
+ self.assertEqual(len(drawEvents[-1]['points']), 5)
+
+ # Rectangle too thin along X: Some points are ignored
+ thinRectX = [(xCenter, yCenter - offset),
+ (xCenter, yCenter + offset),
+ (xCenter + 1, yCenter + offset),
+ (xCenter + 1, yCenter - offset)]
+
+ # Draw while dumping signals
+ events = self._draw(thinRectX)
+
+ # Test last event
+ drawEvents = [event for event in events
+ if event['event'].startswith('drawing')]
+ self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
+ self.assertEqual(len(drawEvents[-1]['points']), 3)
+
+ # Rectangle too thin along Y: Some points are ignored
+ thinRectY = [(xCenter - offset, yCenter),
+ (xCenter + offset, yCenter),
+ (xCenter + offset, yCenter + 1),
+ (xCenter - offset, yCenter + 1)]
+
+ # Draw while dumping signals
+ events = self._draw(thinRectY)
+
+ # Test last event
+ drawEvents = [event for event in events
+ if event['event'].startswith('drawing')]
+ self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
+ self.assertEqual(len(drawEvents[-1]['points']), 3)
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ for TestClass in (TestSelectPolygon,):
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testPlotTools.py b/silx/gui/plot/test/testPlotTools.py
new file mode 100644
index 0000000..1d5e148
--- /dev/null
+++ b/silx/gui/plot/test/testPlotTools.py
@@ -0,0 +1,203 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotTools"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import numpy
+import unittest
+
+from silx.test.utils import ParametricTestCase, TestLogging
+from silx.gui.test.utils import (
+ qWaitForWindowExposedAndActivate, TestCaseQt, getQToolButtonFromAction)
+from silx.gui import qt
+from silx.gui.plot import Plot2D, PlotWindow, PlotTools
+
+
+# Makes sure a QApplication exists
+_qapp = qt.QApplication.instance() or qt.QApplication([])
+
+
+def _tearDownDocTest(docTest):
+ """Tear down to use for test from docstring.
+
+ Checks that plot widget is displayed
+ """
+ plot = docTest.globs['plot']
+ qWaitForWindowExposedAndActivate(plot)
+ plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ plot.close()
+ del plot
+
+# Disable doctest because of
+# "NameError: name 'numpy' is not defined"
+#
+# import doctest
+# positionInfoTestSuite = doctest.DocTestSuite(
+# PlotTools, tearDown=_tearDownDocTest,
+# optionflags=doctest.ELLIPSIS)
+# """Test suite of tests from PlotTools docstrings.
+#
+# Test PositionInfo and ProfileToolBar docstrings.
+# """
+
+
+class TestPositionInfo(TestCaseQt):
+ """Tests for PositionInfo widget."""
+
+ def setUp(self):
+ super(TestPositionInfo, self).setUp()
+ self.plot = PlotWindow()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+ self.mouseMove(self.plot, pos=(1, 1))
+ self.qapp.processEvents()
+ self.qWait(100)
+
+ def tearDown(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+
+ super(TestPositionInfo, self).tearDown()
+
+ def _test(self, positionWidget, converterNames, **kwargs):
+ """General test of PositionInfo.
+
+ - Add it to a toolbar and
+ - Move mouse around the center of the PlotWindow.
+ """
+ toolBar = qt.QToolBar()
+ self.plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar)
+
+ toolBar.addWidget(positionWidget)
+
+ converters = positionWidget.getConverters()
+ self.assertEqual(len(converters), len(converterNames))
+ for index, name in enumerate(converterNames):
+ self.assertEqual(converters[index][0], name)
+
+ with TestLogging(PlotTools.__name__, **kwargs):
+ # Move mouse to center
+ self.mouseMove(self.plot)
+ self.mouseMove(self.plot, pos=(1, 1))
+ self.qapp.processEvents()
+ self.qWait(100)
+
+ def testDefaultConverters(self):
+ """Test PositionInfo with default converters"""
+ positionWidget = PlotTools.PositionInfo(plot=self.plot)
+ self._test(positionWidget, ('X', 'Y'))
+
+ def testCustomConverters(self):
+ """Test PositionInfo with custom converters"""
+ converters = [
+ ('Coords', lambda x, y: (int(x), int(y))),
+ ('Radius', lambda x, y: numpy.sqrt(x * x + y * y)),
+ ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))
+ ]
+ positionWidget = PlotTools.PositionInfo(plot=self.plot,
+ converters=converters)
+ self._test(positionWidget, ('Coords', 'Radius', 'Angle'))
+
+ def testFailingConverters(self):
+ """Test PositionInfo with failing custom converters"""
+ def raiseException(x, y):
+ raise RuntimeError()
+
+ positionWidget = PlotTools.PositionInfo(
+ plot=self.plot,
+ converters=[('Exception', raiseException)])
+ self._test(positionWidget, ['Exception'], error=2)
+
+
+class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase):
+ """Tests for ProfileToolBar widget."""
+
+ def setUp(self):
+ super(TestPixelIntensitiesHisto, self).setUp()
+ self.image = numpy.random.rand(100, 100)
+ self.plotImage = Plot2D()
+ self.plotImage.getIntensityHistogramAction().setVisible(True)
+
+ def tearDown(self):
+ del self.plotImage
+ super(TestPixelIntensitiesHisto, self).tearDown()
+
+ def testShowAndHide(self):
+ """Simple test that the plot is showing and hiding when activating the
+ action"""
+ self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
+ self.plotImage.show()
+
+ histoAction = self.plotImage.getIntensityHistogramAction()
+
+ # test the pixel intensity diagram is showing
+ button = getQToolButtonFromAction(histoAction)
+ self.assertIsNot(button, None)
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ self.assertTrue(histoAction.getHistogramPlotWidget().isVisible())
+
+ # test the pixel intensity diagram is hiding
+ self.qapp.setActiveWindow(self.plotImage)
+ self.qapp.processEvents()
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ self.assertFalse(histoAction.getHistogramPlotWidget().isVisible())
+
+ def testImageFormatInput(self):
+ """Test multiple type as image input"""
+ typesToTest = [numpy.uint8, numpy.int8, numpy.int16, numpy.int32,
+ numpy.float32, numpy.float64]
+ self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
+ self.plotImage.show()
+ button = getQToolButtonFromAction(
+ self.plotImage.getIntensityHistogramAction())
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ for typeToTest in typesToTest:
+ with self.subTest(typeToTest=typeToTest):
+ self.plotImage.addImage(self.image.astype(typeToTest),
+ origin=(0, 0), legend='sino')
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ # test_suite.addTest(positionInfoTestSuite)
+ for testClass in (TestPositionInfo, TestPixelIntensitiesHisto):
+ test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
+ testClass))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testPlotWidget.py b/silx/gui/plot/test/testPlotWidget.py
new file mode 100644
index 0000000..2de18a8
--- /dev/null
+++ b/silx/gui/plot/test/testPlotWidget.py
@@ -0,0 +1,967 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import unittest
+
+import numpy
+
+from silx.test.utils import ParametricTestCase
+from silx.gui.test.utils import TestCaseQt
+
+from silx.gui import qt
+from silx.gui.plot import PlotWidget
+
+
+SIZE = 1024
+"""Size of the test image"""
+
+DATA_2D = numpy.arange(SIZE ** 2).reshape(SIZE, SIZE)
+"""Image data set"""
+
+
+class _PlotWidgetTest(TestCaseQt):
+ """Base class for tests of PlotWidget, not a TestCase in itself.
+
+ plot attribute is the PlotWidget created for the test.
+ """
+
+ def setUp(self):
+ super(_PlotWidgetTest, self).setUp()
+ self.plot = PlotWidget()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(_PlotWidgetTest, self).tearDown()
+
+
+class TestPlotWidget(_PlotWidgetTest, ParametricTestCase):
+ """Basic tests for PlotWidget"""
+
+ def testShow(self):
+ """Most basic test"""
+ pass
+
+ def testSetTitleLabels(self):
+ """Set title and axes labels"""
+
+ title, xlabel, ylabel = 'the title', 'x label', 'y label'
+ self.plot.setGraphTitle(title)
+ self.plot.setGraphXLabel(xlabel)
+ self.plot.setGraphYLabel(ylabel)
+ self.qapp.processEvents()
+
+ self.assertEqual(self.plot.getGraphTitle(), title)
+ self.assertEqual(self.plot.getGraphXLabel(), xlabel)
+ self.assertEqual(self.plot.getGraphYLabel(), ylabel)
+
+ def testChangeLimitsWithAspectRatio(self):
+ def checkLimits(expectedXLim=None, expectedYLim=None,
+ expectedRatio=None):
+ xlim = self.plot.getGraphXLimits()
+ ylim = self.plot.getGraphYLimits()
+ ratio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0])
+
+ if expectedXLim is not None:
+ self.assertEqual(expectedXLim, xlim)
+
+ if expectedYLim is not None:
+ self.assertEqual(expectedYLim, ylim)
+
+ if expectedRatio is not None:
+ self.assertTrue(
+ numpy.allclose(expectedRatio, ratio, atol=0.01))
+
+ self.plot.setKeepDataAspectRatio()
+ self.qapp.processEvents()
+ xlim = self.plot.getGraphXLimits()
+ ylim = self.plot.getGraphYLimits()
+ defaultRatio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0])
+
+ self.plot.setGraphXLimits(1., 10.)
+ checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio)
+ self.qapp.processEvents()
+ checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio)
+
+ self.plot.setGraphYLimits(1., 10.)
+ checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio)
+ self.qapp.processEvents()
+ checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio)
+
+
+class TestPlotImage(_PlotWidgetTest, ParametricTestCase):
+ """Basic tests for addImage"""
+
+ def setUp(self):
+ super(TestPlotImage, self).setUp()
+
+ self.plot.setGraphYLabel('Rows')
+ self.plot.setGraphXLabel('Columns')
+
+ def testPlotColormapTemperature(self):
+ self.plot.setGraphTitle('Temp. Linear')
+
+ colormap = {'name': 'temperature', 'normalization': 'linear',
+ 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0}
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotColormapGray(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle('Gray Linear')
+
+ colormap = {'name': 'gray', 'normalization': 'linear',
+ 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0}
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotColormapTemperatureLog(self):
+ self.plot.setGraphTitle('Temp. Log')
+
+ colormap = {'name': 'temperature', 'normalization': 'log',
+ 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0}
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotRgbRgba(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle('RGB + RGBA')
+
+ rgb = numpy.array(
+ (((0, 0, 0), (128, 0, 0), (255, 0, 0)),
+ ((0, 128, 0), (0, 128, 128), (0, 128, 256))),
+ dtype=numpy.uint8)
+
+ self.plot.addImage(rgb, legend="rgb",
+ origin=(0, 0), scale=(10, 10),
+ replace=False, resetzoom=False)
+
+ rgba = numpy.array(
+ (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)),
+ ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))),
+ dtype=numpy.float32)
+
+ self.plot.addImage(rgba, legend="rgba",
+ origin=(5, 5), scale=(10, 10),
+ replace=False, resetzoom=False)
+
+ self.plot.resetZoom()
+
+ def testPlotColormapCustom(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle('Custom colormap')
+
+ colormap = {'name': None, 'normalization': 'linear',
+ 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0,
+ 'colors': ((0., 0., 0.), (1., 0., 0.),
+ (0., 1., 0.), (0., 0., 1.))}
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap,
+ replace=False, resetzoom=False)
+
+ colormap = {'name': None, 'normalization': 'linear',
+ 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0,
+ 'colors': numpy.array(
+ ((0, 0, 0, 0), (0, 0, 0, 128),
+ (128, 128, 128, 128), (255, 255, 255, 255)),
+ dtype=numpy.uint8)}
+ self.plot.addImage(DATA_2D, legend="image 2", colormap=colormap,
+ origin=(DATA_2D.shape[0], 0),
+ replace=False, resetzoom=False)
+ self.plot.resetZoom()
+
+ def testImageOriginScale(self):
+ """Test of image with different origin and scale"""
+ self.plot.setGraphTitle('origin and scale')
+
+ tests = [ # (origin, scale)
+ ((10, 20), (1, 1)),
+ ((10, 20), (-1, -1)),
+ ((-10, 20), (2, 1)),
+ ((10, -20), (-1, -2)),
+ (100, 2),
+ (-100, (1, 1)),
+ ((10, 20), 2),
+ ]
+
+ for origin, scale in tests:
+ with self.subTest(origin=origin, scale=scale):
+ self.plot.addImage(DATA_2D, origin=origin, scale=scale)
+
+ try:
+ ox, oy = origin
+ except TypeError:
+ ox, oy = origin, origin
+ try:
+ sx, sy = scale
+ except TypeError:
+ sx, sy = scale, scale
+ xbounds = ox, ox + DATA_2D.shape[1] * sx
+ ybounds = oy, oy + DATA_2D.shape[0] * sy
+
+ # Check limits without aspect ratio
+ xmin, xmax = self.plot.getGraphXLimits()
+ ymin, ymax = self.plot.getGraphYLimits()
+ self.assertEqual(xmin, min(xbounds))
+ self.assertEqual(xmax, max(xbounds))
+ self.assertEqual(ymin, min(ybounds))
+ self.assertEqual(ymax, max(ybounds))
+
+ # Check limits with aspect ratio
+ self.plot.setKeepDataAspectRatio(True)
+ xmin, xmax = self.plot.getGraphXLimits()
+ ymin, ymax = self.plot.getGraphYLimits()
+ self.assertTrue(xmin <= min(xbounds))
+ self.assertTrue(xmax >= max(xbounds))
+ self.assertTrue(ymin <= min(ybounds))
+ self.assertTrue(ymax >= max(ybounds))
+
+ self.plot.setKeepDataAspectRatio(False) # Reset aspect ratio
+ self.plot.clear()
+ self.plot.resetZoom()
+
+
+class TestPlotCurve(_PlotWidgetTest):
+ """Basic tests for addCurve."""
+
+ # Test data sets
+ xData = numpy.arange(1000)
+ yData = -500 + 100 * numpy.sin(xData)
+ xData2 = xData + 1000
+ yData2 = xData - 1000 + 200 * numpy.random.random(1000)
+
+ def setUp(self):
+ super(TestPlotCurve, self).setUp()
+ self.plot.setGraphTitle('Curve')
+ self.plot.setGraphYLabel('Rows')
+ self.plot.setGraphXLabel('Columns')
+
+ self.plot.setActiveCurveHandling(False)
+
+ def testPlotCurveColorFloat(self):
+ color = numpy.array(numpy.random.random(3 * 1000),
+ dtype=numpy.float32).reshape(1000, 3)
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 1",
+ replace=False, resetzoom=False,
+ color=color,
+ linestyle="", symbol="s")
+ self.plot.addCurve(self.xData2, self.yData2,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ self.plot.resetZoom()
+
+ def testPlotCurveColorByte(self):
+ color = numpy.array(255 * numpy.random.random(3 * 1000),
+ dtype=numpy.uint8).reshape(1000, 3)
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 1",
+ replace=False, resetzoom=False,
+ color=color,
+ linestyle="", symbol="s")
+ self.plot.addCurve(self.xData2, self.yData2,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ self.plot.resetZoom()
+
+ def testPlotCurveColors(self):
+ color = numpy.array(numpy.random.random(3 * 1000),
+ dtype=numpy.float32).reshape(1000, 3)
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color=color, linestyle="-", symbol='o')
+ self.plot.resetZoom()
+
+
+class TestPlotMarker(_PlotWidgetTest):
+ """Basic tests for add*Marker"""
+
+ def setUp(self):
+ super(TestPlotMarker, self).setUp()
+ self.plot.setGraphYLabel('Rows')
+ self.plot.setGraphXLabel('Columns')
+
+ self.plot.setXAxisAutoScale(False)
+ self.plot.setYAxisAutoScale(False)
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setLimits(0., 100., -100., 100.)
+
+ def testPlotMarkerX(self):
+ self.plot.setGraphTitle('Markers X')
+
+ markers = [
+ (10., 'blue', False, False),
+ (20., 'red', False, False),
+ (40., 'green', True, False),
+ (60., 'gray', True, True),
+ (80., 'black', False, True),
+ ]
+
+ for x, color, select, drag in markers:
+ name = str(x)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addXMarker(x, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerY(self):
+ self.plot.setGraphTitle('Markers Y')
+
+ markers = [
+ (-50., 'blue', False, False),
+ (-30., 'red', False, False),
+ (0., 'green', True, False),
+ (10., 'gray', True, True),
+ (80., 'black', False, True),
+ ]
+
+ for y, color, select, drag in markers:
+ name = str(y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addYMarker(y, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerPt(self):
+ self.plot.setGraphTitle('Markers Pt')
+
+ markers = [
+ (10., -50., 'blue', False, False),
+ (40., -30., 'red', False, False),
+ (50., 0., 'green', True, False),
+ (50., 20., 'gray', True, True),
+ (70., 50., 'black', False, True),
+ ]
+ for x, y, color, select, drag in markers:
+ name = "{0},{1}".format(x, y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addMarker(x, y, name, name, color, select, drag)
+
+ self.plot.resetZoom()
+
+ def testPlotMarkerWithoutLegend(self):
+ self.plot.setGraphTitle('Markers without legend')
+ self.plot.setYAxisInverted(True)
+
+ # Markers without legend
+ self.plot.addMarker(10, 10)
+ self.plot.addMarker(10, 20)
+ self.plot.addMarker(40, 50, text='test', symbol=None)
+ self.plot.addMarker(40, 50, text='test', symbol='+')
+ self.plot.addXMarker(25)
+ self.plot.addXMarker(35)
+ self.plot.addXMarker(45, text='test')
+ self.plot.addYMarker(55)
+ self.plot.addYMarker(65)
+ self.plot.addYMarker(75, text='test')
+
+ self.plot.resetZoom()
+
+
+# TestPlotItem ################################################################
+
+class TestPlotItem(_PlotWidgetTest):
+ """Basic tests for addItem."""
+
+ # Polygon coordinates and color
+ polygons = [ # legend, x coords, y coords, color
+ ('triangle', numpy.array((10, 30, 50)),
+ numpy.array((55, 70, 55)), 'red'),
+ ('square', numpy.array((10, 10, 50, 50)),
+ numpy.array((10, 50, 50, 10)), 'green'),
+ ('star', numpy.array((60, 70, 80, 60, 80)),
+ numpy.array((25, 50, 25, 40, 40)), 'blue'),
+ ]
+
+ # Rectangle coordinantes and color
+ rectangles = [ # legend, x coords, y coords, color
+ ('square 1', numpy.array((1., 10.)),
+ numpy.array((1., 10.)), 'red'),
+ ('square 2', numpy.array((10., 20.)),
+ numpy.array((10., 20.)), 'green'),
+ ('square 3', numpy.array((20., 30.)),
+ numpy.array((20., 30.)), 'blue'),
+ ('rect 1', numpy.array((1., 30.)),
+ numpy.array((35., 40.)), 'black'),
+ ('line h', numpy.array((1., 30.)),
+ numpy.array((45., 45.)), 'darkRed'),
+ ]
+
+ def setUp(self):
+ super(TestPlotItem, self).setUp()
+
+ self.plot.setGraphYLabel('Rows')
+ self.plot.setGraphXLabel('Columns')
+ self.plot.setXAxisAutoScale(False)
+ self.plot.setYAxisAutoScale(False)
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setLimits(0., 100., -100., 100.)
+
+ def testPlotItemPolygonFill(self):
+ self.plot.setGraphTitle('Item Fill')
+
+ for legend, xList, yList, color in self.polygons:
+ self.plot.addItem(xList, yList, legend=legend,
+ replace=False,
+ shape="polygon", fill=True, color=color)
+ self.plot.resetZoom()
+
+ def testPlotItemPolygonNoFill(self):
+ self.plot.setGraphTitle('Item No Fill')
+
+ for legend, xList, yList, color in self.polygons:
+ self.plot.addItem(xList, yList, legend=legend,
+ replace=False,
+ shape="polygon", fill=False, color=color)
+ self.plot.resetZoom()
+
+ def testPlotItemRectangleFill(self):
+ self.plot.setGraphTitle('Rectangle Fill')
+
+ for legend, xList, yList, color in self.rectangles:
+ self.plot.addItem(xList, yList, legend=legend,
+ replace=False,
+ shape="rectangle", fill=True, color=color)
+ self.plot.resetZoom()
+
+ def testPlotItemRectangleNoFill(self):
+ self.plot.setGraphTitle('Rectangle No Fill')
+
+ for legend, xList, yList, color in self.rectangles:
+ self.plot.addItem(xList, yList, legend=legend,
+ replace=False,
+ shape="rectangle", fill=False, color=color)
+ self.plot.resetZoom()
+
+
+class TestPlotActiveCurveImage(_PlotWidgetTest):
+ """Basic tests for active image handling"""
+
+ def testActiveCurveAndLabels(self):
+ # Active curve handling off, no label change
+ self.plot.setActiveCurveHandling(False)
+ self.plot.setGraphXLabel('XLabel')
+ self.plot.setGraphYLabel('YLabel')
+ self.plot.addCurve((1, 2), (1, 2))
+ self.assertEqual(self.plot.getGraphXLabel(), 'XLabel')
+ self.assertEqual(self.plot.getGraphYLabel(), 'YLabel')
+
+ self.plot.addCurve((1, 2), (2, 3), xlabel='x1', ylabel='y1')
+ self.assertEqual(self.plot.getGraphXLabel(), 'XLabel')
+ self.assertEqual(self.plot.getGraphYLabel(), 'YLabel')
+
+ self.plot.clear()
+ self.assertEqual(self.plot.getGraphXLabel(), 'XLabel')
+ self.assertEqual(self.plot.getGraphYLabel(), 'YLabel')
+
+ # Active curve handling on, label changes
+ self.plot.setActiveCurveHandling(True)
+ self.plot.setGraphXLabel('XLabel')
+ self.plot.setGraphYLabel('YLabel')
+
+ # labels changed as active curve
+ self.plot.addCurve((1, 2), (1, 2), legend='1',
+ xlabel='x1', ylabel='y1')
+ self.assertEqual(self.plot.getGraphXLabel(), 'x1')
+ self.assertEqual(self.plot.getGraphYLabel(), 'y1')
+
+ # labels not changed as not active curve
+ self.plot.addCurve((1, 2), (2, 3), legend='2')
+ self.assertEqual(self.plot.getGraphXLabel(), 'x1')
+ self.assertEqual(self.plot.getGraphYLabel(), 'y1')
+
+ # labels changed
+ self.plot.setActiveCurve('2')
+ self.assertEqual(self.plot.getGraphXLabel(), 'XLabel')
+ self.assertEqual(self.plot.getGraphYLabel(), 'YLabel')
+
+ self.plot.setActiveCurve('1')
+ self.assertEqual(self.plot.getGraphXLabel(), 'x1')
+ self.assertEqual(self.plot.getGraphYLabel(), 'y1')
+
+ self.plot.clear()
+ self.assertEqual(self.plot.getGraphXLabel(), 'XLabel')
+ self.assertEqual(self.plot.getGraphYLabel(), 'YLabel')
+
+ def testActiveImageAndLabels(self):
+ # Active image handling always on, no API for toggling it
+ self.plot.setGraphXLabel('XLabel')
+ self.plot.setGraphYLabel('YLabel')
+
+ # labels changed as active curve
+ self.plot.addImage(numpy.arange(100).reshape(10, 10), replace=False,
+ legend='1', xlabel='x1', ylabel='y1')
+ self.assertEqual(self.plot.getGraphXLabel(), 'x1')
+ self.assertEqual(self.plot.getGraphYLabel(), 'y1')
+
+ # labels not changed as not active curve
+ self.plot.addImage(numpy.arange(100).reshape(10, 10), replace=False,
+ legend='2')
+ self.assertEqual(self.plot.getGraphXLabel(), 'x1')
+ self.assertEqual(self.plot.getGraphYLabel(), 'y1')
+
+ # labels changed
+ self.plot.setActiveImage('2')
+ self.assertEqual(self.plot.getGraphXLabel(), 'XLabel')
+ self.assertEqual(self.plot.getGraphYLabel(), 'YLabel')
+
+ self.plot.setActiveImage('1')
+ self.assertEqual(self.plot.getGraphXLabel(), 'x1')
+ self.assertEqual(self.plot.getGraphYLabel(), 'y1')
+
+ self.plot.clear()
+ self.assertEqual(self.plot.getGraphXLabel(), 'XLabel')
+ self.assertEqual(self.plot.getGraphYLabel(), 'YLabel')
+
+
+##############################################################################
+# Log
+##############################################################################
+
+class TestPlotEmptyLog(_PlotWidgetTest):
+ """Basic tests for log plot"""
+ def testEmptyPlotTitleLabelsLog(self):
+ self.plot.setGraphTitle('Empty Log Log')
+ self.plot.setGraphXLabel('X')
+ self.plot.setGraphYLabel('Y')
+ self.plot.setXAxisLogarithmic(True)
+ self.plot.setYAxisLogarithmic(True)
+ self.plot.resetZoom()
+
+
+class TestPlotCurveLog(_PlotWidgetTest, ParametricTestCase):
+ """Basic tests for addCurve with log scale axes"""
+
+ # Test data
+ xData = numpy.arange(1000) + 1
+ yData = xData ** 2
+
+ def _setLabels(self):
+ self.plot.setGraphXLabel('X')
+ self.plot.setGraphYLabel('X * X')
+
+ def testPlotCurveLogX(self):
+ self._setLabels()
+ self.plot.setXAxisLogarithmic(True)
+ self.plot.setGraphTitle('Curve X: Log Y: Linear')
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+
+ def testPlotCurveLogY(self):
+ self._setLabels()
+ self.plot.setYAxisLogarithmic(True)
+
+ self.plot.setGraphTitle('Curve X: Linear Y: Log')
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+
+ def testPlotCurveLogXY(self):
+ self._setLabels()
+ self.plot.setXAxisLogarithmic(True)
+ self.plot.setYAxisLogarithmic(True)
+
+ self.plot.setGraphTitle('Curve X: Log Y: Log')
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+
+ def testPlotCurveErrorLogXY(self):
+ self.plot.setXAxisLogarithmic(True)
+ self.plot.setYAxisLogarithmic(True)
+
+ # Every second error leads to negative number
+ errors = numpy.ones_like(self.xData)
+ errors[::2] = self.xData[::2] + 1
+
+ tests = [ # name, xerror, yerror
+ ('xerror=3', 3, None),
+ ('xerror=N array', errors, None),
+ ('xerror=Nx1 array', errors.reshape(len(errors), 1), None),
+ ('xerror=2xN array', numpy.array((errors, errors)), None),
+ ('yerror=6', None, 6),
+ ('yerror=N array', None, errors ** 2),
+ ('yerror=Nx1 array', None, (errors ** 2).reshape(len(errors), 1)),
+ ('yerror=2xN array', None, numpy.array((errors, errors)) ** 2),
+ ]
+
+ for name, xError, yError in tests:
+ with self.subTest(name):
+ self.plot.setGraphTitle(name)
+ self.plot.addCurve(self.xData, self.yData,
+ legend=name,
+ xerror=xError, yerror=yError,
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+
+ self.qapp.processEvents()
+
+ self.plot.clear()
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ def testPlotCurveToggleLog(self):
+ """Add a curve with negative data and toggle log axis"""
+ arange = numpy.arange(1000) + 1
+ tests = [ # name, xData, yData
+ ('x>0, some negative y', arange, arange - 500),
+ ('x>0, y<0', arange, -arange),
+ ('some negative x, y>0', arange - 500, arange),
+ ('x<0, y>0', -arange, arange),
+ ('some negative x and y', arange - 500, arange - 500),
+ ('x<0, y<0', -arange, -arange),
+ ]
+
+ for name, xData, yData in tests:
+ with self.subTest(name):
+ self.plot.addCurve(xData, yData, resetzoom=True)
+ self.qapp.processEvents()
+
+ # no log axis
+ xLim = self.plot.getGraphXLimits()
+ self.assertEqual(xLim, (min(xData), max(xData)))
+ yLim = self.plot.getGraphYLimits()
+ self.assertEqual(yLim, (min(yData), max(yData)))
+
+ # x axis log
+ self.plot.setXAxisLogarithmic(True)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getGraphXLimits()
+ yLim = self.plot.getGraphYLimits()
+ positives = xData > 0
+ if numpy.any(positives):
+ self.assertTrue(numpy.allclose(
+ xLim, (min(xData[positives]), max(xData[positives]))))
+ self.assertEqual(
+ yLim, (min(yData[positives]), max(yData[positives])))
+ else: # No positive x in the curve
+ self.assertEqual(xLim, (1., 100.))
+ self.assertEqual(yLim, (1., 100.))
+
+ # x axis and y axis log
+ self.plot.setYAxisLogarithmic(True)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getGraphXLimits()
+ yLim = self.plot.getGraphYLimits()
+ positives = numpy.logical_and(xData > 0, yData > 0)
+ if numpy.any(positives):
+ self.assertTrue(numpy.allclose(
+ xLim, (min(xData[positives]), max(xData[positives]))))
+ self.assertTrue(numpy.allclose(
+ yLim, (min(yData[positives]), max(yData[positives]))))
+ else: # No positive x and y in the curve
+ self.assertEqual(xLim, (1., 100.))
+ self.assertEqual(yLim, (1., 100.))
+
+ # y axis log
+ self.plot.setXAxisLogarithmic(False)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getGraphXLimits()
+ yLim = self.plot.getGraphYLimits()
+ positives = yData > 0
+ if numpy.any(positives):
+ self.assertEqual(
+ xLim, (min(xData[positives]), max(xData[positives])))
+ self.assertTrue(numpy.allclose(
+ yLim, (min(yData[positives]), max(yData[positives]))))
+ else: # No positive y in the curve
+ self.assertEqual(xLim, (1., 100.))
+ self.assertEqual(yLim, (1., 100.))
+
+ # no log axis
+ self.plot.setYAxisLogarithmic(False)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getGraphXLimits()
+ self.assertEqual(xLim, (min(xData), max(xData)))
+ yLim = self.plot.getGraphYLimits()
+ self.assertEqual(yLim, (min(yData), max(yData)))
+
+ self.plot.clear()
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+
+class TestPlotImageLog(_PlotWidgetTest):
+ """Basic tests for addImage with log scale axes."""
+
+ def setUp(self):
+ super(TestPlotImageLog, self).setUp()
+
+ self.plot.setGraphXLabel('Columns')
+ self.plot.setGraphYLabel('Rows')
+
+ def testPlotColormapGrayLogX(self):
+ self.plot.setXAxisLogarithmic(True)
+ self.plot.setGraphTitle('CMap X: Log Y: Linear')
+
+ colormap = {'name': 'gray', 'normalization': 'linear',
+ 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0}
+ self.plot.addImage(DATA_2D, legend="image 1",
+ origin=(1., 1.), scale=(1., 1.),
+ replace=False, resetzoom=False, colormap=colormap)
+ self.plot.resetZoom()
+
+ def testPlotColormapGrayLogY(self):
+ self.plot.setYAxisLogarithmic(True)
+ self.plot.setGraphTitle('CMap X: Linear Y: Log')
+
+ colormap = {'name': 'gray', 'normalization': 'linear',
+ 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0}
+ self.plot.addImage(DATA_2D, legend="image 1",
+ origin=(1., 1.), scale=(1., 1.),
+ replace=False, resetzoom=False, colormap=colormap)
+ self.plot.resetZoom()
+
+ def testPlotColormapGrayLogXY(self):
+ self.plot.setXAxisLogarithmic(True)
+ self.plot.setYAxisLogarithmic(True)
+ self.plot.setGraphTitle('CMap X: Log Y: Log')
+
+ colormap = {'name': 'gray', 'normalization': 'linear',
+ 'autoscale': True, 'vmin': 0.0, 'vmax': 1.0}
+ self.plot.addImage(DATA_2D, legend="image 1",
+ origin=(1., 1.), scale=(1., 1.),
+ replace=False, resetzoom=False, colormap=colormap)
+ self.plot.resetZoom()
+
+ def testPlotRgbRgbaLogXY(self):
+ self.plot.setXAxisLogarithmic(True)
+ self.plot.setYAxisLogarithmic(True)
+ self.plot.setGraphTitle('RGB + RGBA X: Log Y: Log')
+
+ rgb = numpy.array(
+ (((0, 0, 0), (128, 0, 0), (255, 0, 0)),
+ ((0, 128, 0), (0, 128, 128), (0, 128, 256))),
+ dtype=numpy.uint8)
+
+ self.plot.addImage(rgb, legend="rgb",
+ origin=(1, 1), scale=(10, 10),
+ replace=False, resetzoom=False)
+
+ rgba = numpy.array(
+ (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)),
+ ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))),
+ dtype=numpy.float32)
+
+ self.plot.addImage(rgba, legend="rgba",
+ origin=(5., 5.), scale=(10., 10.),
+ replace=False, resetzoom=False)
+ self.plot.resetZoom()
+
+
+class TestPlotMarkerLog(_PlotWidgetTest):
+ """Basic tests for markers on log scales"""
+
+ # Test marker parameters
+ markers = [ # x, y, color, selectable, draggable
+ (10., 10., 'blue', False, False),
+ (20., 20., 'red', False, False),
+ (40., 100., 'green', True, False),
+ (40., 500., 'gray', True, True),
+ (60., 800., 'black', False, True),
+ ]
+
+ def setUp(self):
+ super(TestPlotMarkerLog, self).setUp()
+
+ self.plot.setGraphYLabel('Rows')
+ self.plot.setGraphXLabel('Columns')
+ self.plot.setXAxisAutoScale(False)
+ self.plot.setYAxisAutoScale(False)
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setLimits(1., 100., 1., 1000.)
+ self.plot.setXAxisLogarithmic(True)
+ self.plot.setYAxisLogarithmic(True)
+
+ def testPlotMarkerXLog(self):
+ self.plot.setGraphTitle('Markers X, Log axes')
+
+ for x, _, color, select, drag in self.markers:
+ name = str(x)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addXMarker(x, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerYLog(self):
+ self.plot.setGraphTitle('Markers Y, Log axes')
+
+ for _, y, color, select, drag in self.markers:
+ name = str(y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addYMarker(y, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerPtLog(self):
+ self.plot.setGraphTitle('Markers Pt, Log axes')
+
+ for x, y, color, select, drag in self.markers:
+ name = "{0},{1}".format(x, y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addMarker(x, y, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+
+class TestPlotItemLog(_PlotWidgetTest):
+ """Basic tests for items with log scale axes"""
+
+ # Polygon coordinates and color
+ polygons = [ # legend, x coords, y coords, color
+ ('triangle', numpy.array((10, 30, 50)),
+ numpy.array((55, 70, 55)), 'red'),
+ ('square', numpy.array((10, 10, 50, 50)),
+ numpy.array((10, 50, 50, 10)), 'green'),
+ ('star', numpy.array((60, 70, 80, 60, 80)),
+ numpy.array((25, 50, 25, 40, 40)), 'blue'),
+ ]
+
+ # Rectangle coordinantes and color
+ rectangles = [ # legend, x coords, y coords, color
+ ('square 1', numpy.array((1., 10.)),
+ numpy.array((1., 10.)), 'red'),
+ ('square 2', numpy.array((10., 20.)),
+ numpy.array((10., 20.)), 'green'),
+ ('square 3', numpy.array((20., 30.)),
+ numpy.array((20., 30.)), 'blue'),
+ ('rect 1', numpy.array((1., 30.)),
+ numpy.array((35., 40.)), 'black'),
+ ('line h', numpy.array((1., 30.)),
+ numpy.array((45., 45.)), 'darkRed'),
+ ]
+
+ def setUp(self):
+ super(TestPlotItemLog, self).setUp()
+
+ self.plot.setGraphYLabel('Rows')
+ self.plot.setGraphXLabel('Columns')
+ self.plot.setXAxisAutoScale(False)
+ self.plot.setYAxisAutoScale(False)
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setLimits(1., 100., 1., 100.)
+ self.plot.setXAxisLogarithmic(True)
+ self.plot.setYAxisLogarithmic(True)
+
+ def testPlotItemPolygonLogFill(self):
+ self.plot.setGraphTitle('Item Fill Log')
+
+ for legend, xList, yList, color in self.polygons:
+ self.plot.addItem(xList, yList, legend=legend,
+ replace=False,
+ shape="polygon", fill=True, color=color)
+ self.plot.resetZoom()
+
+ def testPlotItemPolygonLogNoFill(self):
+ self.plot.setGraphTitle('Item No Fill Log')
+
+ for legend, xList, yList, color in self.polygons:
+ self.plot.addItem(xList, yList, legend=legend,
+ replace=False,
+ shape="polygon", fill=False, color=color)
+ self.plot.resetZoom()
+
+ def testPlotItemRectangleLogFill(self):
+ self.plot.setGraphTitle('Rectangle Fill Log')
+
+ for legend, xList, yList, color in self.rectangles:
+ self.plot.addItem(xList, yList, legend=legend,
+ replace=False,
+ shape="rectangle", fill=True, color=color)
+ self.plot.resetZoom()
+
+ def testPlotItemRectangleLogNoFill(self):
+ self.plot.setGraphTitle('Rectangle No Fill Log')
+
+ for legend, xList, yList, color in self.rectangles:
+ self.plot.addItem(xList, yList, legend=legend,
+ replace=False,
+ shape="rectangle", fill=False, color=color)
+ self.plot.resetZoom()
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotWidget))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotImage))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotCurve))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotMarker))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotItem))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotEmptyLog))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotCurveLog))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotImageLog))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotMarkerLog))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotItemLog))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testPlotWindow.py b/silx/gui/plot/test/testPlotWindow.py
new file mode 100644
index 0000000..5afd53a
--- /dev/null
+++ b/silx/gui/plot/test/testPlotWindow.py
@@ -0,0 +1,138 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWindow"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import doctest
+import unittest
+
+from silx.gui.test.utils import TestCaseQt, getQToolButtonFromAction
+
+from silx.gui import qt
+from silx.gui.plot import PlotWindow
+
+
+# Test of the docstrings #
+
+# Makes sure a QApplication exists
+_qapp = qt.QApplication.instance() or qt.QApplication([])
+
+
+def _tearDownQt(docTest):
+ """Tear down to use for test from docstring.
+
+ Checks that plt widget is displayed
+ """
+ _qapp.processEvents()
+ for obj in docTest.globs.values():
+ if isinstance(obj, PlotWindow):
+ # Commented out as it takes too long
+ # qWaitForWindowExposedAndActivate(obj)
+ obj.setAttribute(qt.Qt.WA_DeleteOnClose)
+ obj.close()
+ del obj
+
+
+plotWindowDocTestSuite = doctest.DocTestSuite('silx.gui.plot.PlotWindow',
+ tearDown=_tearDownQt)
+"""Test suite of tests from the module's docstrings."""
+
+
+class TestPlotWindow(TestCaseQt):
+ """Base class for tests of PlotWindow."""
+
+ def setUp(self):
+ super(TestPlotWindow, self).setUp()
+ self.plot = PlotWindow()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestPlotWindow, self).tearDown()
+
+ def testActions(self):
+ """Test the actions QToolButtons"""
+ self.plot.setLimits(1, 100, 1, 100)
+
+ checkList = [ # QAction, Plot state getter
+ (self.plot.xAxisAutoScaleAction, self.plot.isXAxisAutoScale),
+ (self.plot.yAxisAutoScaleAction, self.plot.isYAxisAutoScale),
+ (self.plot.xAxisLogarithmicAction, self.plot.isXAxisLogarithmic),
+ (self.plot.yAxisLogarithmicAction, self.plot.isYAxisLogarithmic),
+ (self.plot.gridAction, self.plot.getGraphGrid),
+ ]
+
+ for action, getter in checkList:
+ self.mouseMove(self.plot)
+ initialState = getter()
+ toolButton = getQToolButtonFromAction(action)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.assertNotEqual(getter(), initialState,
+ msg='"%s" state not changed' % action.text())
+
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.assertEqual(getter(), initialState,
+ msg='"%s" state not changed' % action.text())
+
+ # Trigger a zoom reset
+ self.mouseMove(self.plot)
+ resetZoomAction = self.plot.resetZoomAction
+ toolButton = getQToolButtonFromAction(resetZoomAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ def testToolAspectRatio(self):
+ self.plot.toolBar()
+ self.plot.keepDataAspectRatioButton.keepDataAspectRatio()
+ self.assertTrue(self.plot.isKeepDataAspectRatio())
+ self.plot.keepDataAspectRatioButton.dontKeepDataAspectRatio()
+ self.assertFalse(self.plot.isKeepDataAspectRatio())
+
+ def testToolYAxisOrigin(self):
+ self.plot.toolBar()
+ self.plot.yAxisInvertedButton.setYAxisUpward()
+ self.assertFalse(self.plot.isYAxisInverted())
+ self.plot.yAxisInvertedButton.setYAxisDownward()
+ self.assertTrue(self.plot.isYAxisInverted())
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(plotWindowDocTestSuite)
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotWindow))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testProfile.py b/silx/gui/plot/test/testProfile.py
new file mode 100644
index 0000000..43d3329
--- /dev/null
+++ b/silx/gui/plot/test/testProfile.py
@@ -0,0 +1,183 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for Profile"""
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "23/02/2017"
+
+import numpy
+import unittest
+
+from silx.test.utils import ParametricTestCase
+from silx.gui.test.utils import (
+ TestCaseQt, getQToolButtonFromAction)
+from silx.gui import qt
+from silx.gui.plot import PlotWindow, Plot1D, Plot2D, Profile
+from silx.gui.plot.StackView import StackView
+
+
+# Makes sure a QApplication exists
+_qapp = qt.QApplication.instance() or qt.QApplication([])
+
+
+class TestProfileToolBar(TestCaseQt, ParametricTestCase):
+ """Tests for ProfileToolBar widget."""
+
+ def setUp(self):
+ super(TestProfileToolBar, self).setUp()
+ profileWindow = PlotWindow()
+ self.plot = PlotWindow()
+ self.toolBar = Profile.ProfileToolBar(
+ plot=self.plot, profileWindow=profileWindow)
+ self.plot.addToolBar(self.toolBar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+ profileWindow.show()
+ self.qWaitForWindowExposed(profileWindow)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.toolBar
+
+ super(TestProfileToolBar, self).tearDown()
+
+ def testAlignedProfile(self):
+ """Test horizontal and vertical profile, without and with image"""
+ # Use Plot backend widget to submit mouse events
+ widget = self.plot.getWidgetHandle()
+
+ # 2 positions to use for mouse events
+ pos1 = widget.width() * 0.4, widget.height() * 0.4
+ pos2 = widget.width() * 0.6, widget.height() * 0.6
+
+ for action in (self.toolBar.hLineAction, self.toolBar.vLineAction):
+ with self.subTest(mode=action.text()):
+ # Trigger tool button for mode
+ toolButton = getQToolButtonFromAction(action)
+ self.assertIsNot(toolButton, None)
+ self.mouseMove(toolButton)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # Without image
+ self.mouseMove(widget, pos=pos1)
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=pos1)
+
+ # with image
+ self.plot.addImage(numpy.arange(100 * 100).reshape(100, -1))
+ self.mousePress(widget, qt.Qt.LeftButton, pos=pos1)
+ self.mouseMove(widget, pos=pos2)
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2)
+
+ self.mouseMove(widget)
+ self.mouseClick(widget, qt.Qt.LeftButton)
+
+ def testDiagonalProfile(self):
+ """Test diagonal profile, without and with image"""
+ # Use Plot backend widget to submit mouse events
+ widget = self.plot.getWidgetHandle()
+
+ # 2 positions to use for mouse events
+ pos1 = widget.width() * 0.4, widget.height() * 0.4
+ pos2 = widget.width() * 0.6, widget.height() * 0.6
+
+ # Trigger tool button for diagonal profile mode
+ toolButton = getQToolButtonFromAction(self.toolBar.lineAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseMove(toolButton)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ for image in (False, True):
+ with self.subTest(image=image):
+ if image:
+ self.plot.addImage(numpy.arange(100 * 100).reshape(100, -1))
+
+ self.mouseMove(widget, pos=pos1)
+ self.mousePress(widget, qt.Qt.LeftButton, pos=pos1)
+ self.mouseMove(widget, pos=pos2)
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2)
+
+ self.plot.clear()
+
+
+class TestGetProfilePlot(TestCaseQt):
+
+ def testProfile1D(self):
+ plot = Plot2D()
+ plot.show()
+ self.qWaitForWindowExposed(plot)
+ plot.addImage([[0, 1], [2, 3]])
+ self.assertIsInstance(plot.getProfileToolbar().getProfileMainWindow(),
+ qt.QMainWindow)
+ self.assertIsInstance(plot.getProfilePlot(),
+ Plot1D)
+ plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ plot.close()
+ del plot
+
+ def testProfile2D(self):
+ """Test that the profile plot associated to a stack view is either a
+ Plot1D or a plot 2D instance."""
+ plot = StackView()
+ plot.show()
+ self.qWaitForWindowExposed(plot)
+
+ plot.setStack(numpy.array([[[0, 1], [2, 3]],
+ [[4, 5], [6, 7]]]))
+
+ self.assertIsInstance(plot.getProfileToolbar().getProfileMainWindow(),
+ qt.QMainWindow)
+
+ # plot.getProfileToolbar().profile3dAction.computeProfileIn2D() # default
+
+ self.assertIsInstance(plot.getProfileToolbar().getProfilePlot(),
+ Plot2D)
+ plot.getProfileToolbar().profile3dAction.computeProfileIn1D()
+ self.assertIsInstance(plot.getProfileToolbar().getProfilePlot(),
+ Plot1D)
+
+ plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ plot.close()
+ del plot
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ # test_suite.addTest(positionInfoTestSuite)
+ for testClass in (TestProfileToolBar, TestGetProfilePlot):
+ test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
+ testClass))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testScatterMaskToolsWidget.py b/silx/gui/plot/test/testScatterMaskToolsWidget.py
new file mode 100644
index 0000000..8b5f2ad
--- /dev/null
+++ b/silx/gui/plot/test/testScatterMaskToolsWidget.py
@@ -0,0 +1,313 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for MaskToolsWidget"""
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "10/07/2017"
+
+
+import logging
+import os.path
+import unittest
+
+import numpy
+
+from silx.gui import qt
+from silx.test.utils import temp_dir, ParametricTestCase
+from silx.gui.test.utils import TestCaseQt, getQToolButtonFromAction
+from silx.gui.plot import PlotWindow, ScatterMaskToolsWidget
+
+try:
+ import fabio
+except ImportError:
+ fabio = None
+
+
+logging.basicConfig()
+_logger = logging.getLogger(__name__)
+
+
+class TestScatterMaskToolsWidget(TestCaseQt, ParametricTestCase):
+ """Basic test for MaskToolsWidget"""
+
+ def setUp(self):
+ super(TestScatterMaskToolsWidget, self).setUp()
+ self.plot = PlotWindow()
+
+ self.widget = ScatterMaskToolsWidget.ScatterMaskToolsDockWidget(
+ plot=self.plot, name='TEST')
+ self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.maskWidget = self.widget.widget()
+
+ def tearDown(self):
+ del self.maskWidget
+ del self.widget
+
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+
+ super(TestScatterMaskToolsWidget, self).tearDown()
+
+ def testEmptyPlot(self):
+ """Empty plot, display MaskToolsDockWidget, toggle multiple masks"""
+ self.maskWidget.setMultipleMasks('single')
+ self.qapp.processEvents()
+
+ self.maskWidget.setMultipleMasks('exclusive')
+ self.qapp.processEvents()
+
+ def _drag(self):
+ """Drag from plot center to offset position"""
+ plot = self.plot.centralWidget()
+ xCenter, yCenter = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ pos0 = xCenter, yCenter
+ pos1 = xCenter + offset, yCenter + offset
+
+ self.mouseMove(plot, pos=pos0)
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos0)
+ self.mouseMove(plot, pos=pos1)
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1)
+
+ def _drawPolygon(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.centralWidget()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [(x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset)]
+
+ for pos in star:
+ self.mouseMove(plot, pos=pos)
+ btn = qt.Qt.LeftButton if pos != star[-1] else qt.Qt.RightButton
+ self.mouseClick(plot, btn, pos=pos)
+
+ def _drawPencil(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.centralWidget()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [(x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset)]
+
+ self.mouseMove(plot, pos=star[0])
+ self.mousePress(plot, qt.Qt.LeftButton, pos=star[0])
+ for pos in star:
+ self.mouseMove(plot, pos=pos)
+ self.mouseRelease(
+ plot, qt.Qt.LeftButton, pos=star[-1])
+
+ def testWithAScatter(self):
+ """Plot with a Scatter: test MaskToolsWidget interactions"""
+
+ # Add and remove a scatter (this should enable/disable GUI + change mask)
+ self.plot.addScatter(
+ x=numpy.arange(256),
+ y=numpy.arange(256),
+ value=numpy.random.random(256),
+ legend='test')
+ self.plot._setActiveItem(kind="scatter", legend="test")
+ self.qapp.processEvents()
+
+ self.plot.remove('test', kind='scatter')
+ self.qapp.processEvents()
+
+ self.plot.addScatter(
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.random.random(1000),
+ legend='test')
+ self.plot._setActiveItem(kind="scatter", legend="test")
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ # Test draw rectangle #
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test draw polygon #
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test draw pencil #
+ toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.maskWidget.pencilSpinBox.setValue(10)
+ self.qapp.processEvents()
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test no draw tool #
+ toolButton = getQToolButtonFromAction(self.maskWidget.browseAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.plot.clear()
+
+ def __loadSave(self, file_format):
+ self.plot.addScatter(
+ x=numpy.arange(256),
+ y=25 * (numpy.arange(256) % 10),
+ value=numpy.random.random(256),
+ legend='test')
+ self.plot._setActiveItem(kind="scatter", legend="test")
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ # Draw a polygon mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self._drawPolygon()
+
+ ref_mask = self.maskWidget.getSelectionMask()
+ self.assertFalse(numpy.all(numpy.equal(ref_mask, 0)))
+
+ with temp_dir() as tmp:
+ mask_filename = os.path.join(tmp, 'mask.' + file_format)
+ self.maskWidget.save(mask_filename, file_format)
+
+ self.maskWidget.resetSelectionMask()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ self.maskWidget.load(mask_filename)
+ self.assertTrue(numpy.all(numpy.equal(
+ self.maskWidget.getSelectionMask(), ref_mask)))
+
+ def testLoadSaveNpy(self):
+ self.__loadSave("npy")
+
+ def testLoadSaveCsv(self):
+ self.__loadSave("csv")
+
+ def testSigMaskChangedEmitted(self):
+ self.qapp.processEvents()
+ self.plot.addScatter(
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.ones((1000,)),
+ legend='test')
+ self.plot._setActiveItem(kind="scatter", legend="test")
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ self.plot.remove('test', kind='scatter')
+ self.qapp.processEvents()
+
+ self.plot.addScatter(
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.random.random(1000),
+ legend='test')
+
+ l = []
+
+ def slot():
+ l.append(1)
+
+ self.maskWidget.sigMaskChanged.connect(slot)
+
+ # rectangle mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+
+ self.assertGreater(len(l), 0)
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ for TestClass in (TestScatterMaskToolsWidget,):
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testStackView.py b/silx/gui/plot/test/testStackView.py
new file mode 100644
index 0000000..69584cd
--- /dev/null
+++ b/silx/gui/plot/test/testStackView.py
@@ -0,0 +1,209 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for StackView"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "20/03/2017"
+
+
+import unittest
+import numpy
+
+from silx.gui.test.utils import TestCaseQt
+
+from silx.gui import qt
+from silx.gui.plot import StackView
+from silx.gui.plot.StackView import StackViewMainWindow
+
+from silx.utils.array_like import ListOfImages
+
+
+# Makes sure a QApplication exists
+_qapp = qt.QApplication.instance() or qt.QApplication([])
+
+
+class TestStackView(TestCaseQt):
+ """Base class for tests of StackView."""
+
+ def setUp(self):
+ super(TestStackView, self).setUp()
+ self.stackview = StackView()
+ self.stackview.show()
+ self.qWaitForWindowExposed(self.stackview)
+ self.mystack = numpy.fromfunction(
+ lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
+ (10, 20, 30)
+ )
+
+ def tearDown(self):
+ self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.stackview.close()
+ del self.stackview
+ super(TestStackView, self).tearDown()
+
+ def testSetStack(self):
+ self.stackview.setStack(self.mystack)
+ self.stackview.setColormap("viridis", autoscale=True)
+ my_trans_stack, params = self.stackview.getStack()
+ self.assertEqual(my_trans_stack.shape, self.mystack.shape)
+ self.assertTrue(numpy.array_equal(self.mystack,
+ my_trans_stack))
+ self.assertEqual(params["colormap"]["name"],
+ "viridis")
+
+ def testSetStackPerspective(self):
+ self.stackview.setStack(self.mystack, perspective=1)
+ # my_orig_stack, params = self.stackview.getStack()
+ my_trans_stack, params = self.stackview.getCurrentView()
+
+ # get stack returns the transposed data, depending on the perspective
+ self.assertEqual(my_trans_stack.shape,
+ (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]))
+ self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)),
+ my_trans_stack))
+
+ def testSetStackListOfImages(self):
+ loi = [self.mystack[i] for i in range(self.mystack.shape[0])]
+
+ self.stackview.setStack(loi)
+ my_orig_stack, params = self.stackview.getStack(returnNumpyArray=True)
+ my_trans_stack, params = self.stackview.getStack(returnNumpyArray=True)
+ self.assertEqual(my_trans_stack.shape, self.mystack.shape)
+ self.assertTrue(numpy.array_equal(self.mystack,
+ my_trans_stack))
+ self.assertTrue(numpy.array_equal(self.mystack,
+ my_orig_stack))
+ self.assertIsInstance(my_trans_stack, numpy.ndarray)
+
+ self.stackview.setStack(loi, perspective=2)
+ my_orig_stack, params = self.stackview.getStack(copy=False)
+ my_trans_stack, params = self.stackview.getCurrentView(copy=False)
+ # getStack(copy=False) must return the object set in setStack
+ self.assertIs(my_orig_stack, loi)
+ # getCurrentView(copy=False) returns a ListOfImages whose .images
+ # attr is the original data
+ self.assertEqual(my_trans_stack.shape,
+ (self.mystack.shape[2], self.mystack.shape[0], self.mystack.shape[1]))
+ self.assertTrue(numpy.array_equal(numpy.array(my_trans_stack),
+ numpy.transpose(self.mystack, axes=(2, 0, 1))))
+ self.assertIsInstance(my_trans_stack,
+ ListOfImages) # returnNumpyArray=False by default in getStack
+ self.assertIs(my_trans_stack.images, loi)
+
+ def testPerspective(self):
+ self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)))
+ self.assertEqual(self.stackview._perspective, 0,
+ "Default perspective is not 0 (dim1-dim2).")
+
+ self.stackview._StackView__planeSelection.setPerspective(1)
+ self.assertEqual(self.stackview._perspective, 1,
+ "Plane selection combobox not updating perspective")
+
+ self.stackview.setStack(numpy.arange(6).reshape((1, 2, 3)))
+ self.assertEqual(self.stackview._perspective, 0,
+ "Default perspective not restored in setStack.")
+
+ self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)), perspective=2)
+ self.assertEqual(self.stackview._perspective, 2,
+ "Perspective not set in setStack(..., perspective=2).")
+
+ def testTitle(self):
+ """Test that the plot title contains the proper Z information"""
+ self.stackview.setStack(numpy.arange(24).reshape((4, 3, 2)),
+ calibrations=[(0, 1), (-10, 10), (3.14, 3.14)])
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=0")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=2")
+
+ self.stackview._StackView__planeSelection.setPerspective(1)
+ self.stackview.setFrameNumber(0)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=-10")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=10")
+
+ self.stackview._StackView__planeSelection.setPerspective(2)
+ self.stackview.setFrameNumber(0)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=3.14")
+ self.stackview.setFrameNumber(1)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=6.28")
+
+
+class TestStackViewMainWindow(TestCaseQt):
+ """Base class for tests of StackView."""
+
+ def setUp(self):
+ super(TestStackViewMainWindow, self).setUp()
+ self.stackview = StackViewMainWindow()
+ self.stackview.show()
+ self.qWaitForWindowExposed(self.stackview)
+ self.mystack = numpy.fromfunction(
+ lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
+ (10, 20, 30)
+ )
+
+ def tearDown(self):
+ self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.stackview.close()
+ del self.stackview
+ super(TestStackViewMainWindow, self).tearDown()
+
+ def testSetStack(self):
+ self.stackview.setStack(self.mystack)
+ self.stackview.setColormap("viridis", autoscale=True)
+ my_trans_stack, params = self.stackview.getStack()
+ self.assertEqual(my_trans_stack.shape, self.mystack.shape)
+ self.assertTrue(numpy.array_equal(self.mystack,
+ my_trans_stack))
+ self.assertEqual(params["colormap"]["name"],
+ "viridis")
+
+ def testSetStackPerspective(self):
+ self.stackview.setStack(self.mystack, perspective=1)
+ my_trans_stack, params = self.stackview.getCurrentView()
+ # get stack returns the transposed data, depending on the perspective
+ self.assertEqual(my_trans_stack.shape,
+ (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]))
+ self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)),
+ my_trans_stack))
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestStackView))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestStackViewMainWindow))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot3d/Plot3DActions.py b/silx/gui/plot3d/Plot3DActions.py
new file mode 100644
index 0000000..2ae2750
--- /dev/null
+++ b/silx/gui/plot3d/Plot3DActions.py
@@ -0,0 +1,362 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides QAction that can be attached to a plot3DWidget."""
+
+from __future__ import absolute_import, division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/01/2017"
+
+
+import logging
+import os
+import weakref
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.plot.PlotActions import PrintAction as _PrintAction
+from silx.gui.icons import getQIcon
+from .utils import mng
+from .._utils import convertQImageToArray
+
+
+_logger = logging.getLogger(__name__)
+
+
+class Plot3DAction(qt.QAction):
+ """QAction associated to a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param Plot3DWidget plot3d: Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(Plot3DAction, self).__init__(parent)
+ self._plot3d = None
+ self.setPlot3DWidget(plot3d)
+
+ def setPlot3DWidget(self, widget):
+ """Set the Plot3DWidget this action is associated with
+
+ :param Plot3DWidget widget: The Plot3DWidget to use
+ """
+ self._plot3d = None if widget is None else weakref.ref(widget)
+
+ def getPlot3DWidget(self):
+ """Return the Plot3DWidget associated to this action.
+
+ If no widget is associated, it returns None.
+
+ :rtype: qt.QWidget
+ """
+ return None if self._plot3d is None else self._plot3d()
+
+
+class CopyAction(Plot3DAction):
+ """QAction to provide copy of a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param Plot3DWidget plot3d: Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(CopyAction, self).__init__(parent, plot3d)
+
+ self.setIcon(getQIcon('edit-copy'))
+ self.setText('Copy')
+ self.setToolTip('Copy a snapshot of the 3D scene to the clipboard')
+ self.setCheckable(False)
+ self.setShortcut(qt.QKeySequence.Copy)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error('Cannot copy widget, no associated Plot3DWidget')
+ else:
+ image = plot3d.grabGL()
+ qt.QApplication.clipboard().setImage(image)
+
+
+class SaveAction(Plot3DAction):
+ """QAction to provide save snapshot of a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param Plot3DWidget plot3d: Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(SaveAction, self).__init__(parent, plot3d)
+
+ self.setIcon(getQIcon('document-save'))
+ self.setText('Save...')
+ self.setToolTip('Save a snapshot of the 3D scene')
+ self.setCheckable(False)
+ self.setShortcut(qt.QKeySequence.Save)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error('Cannot save widget, no associated Plot3DWidget')
+ else:
+ dialog = qt.QFileDialog(self.parent())
+ dialog.setWindowTitle('Save snapshot as')
+ dialog.setModal(True)
+ dialog.setNameFilters(('Plot3D Snapshot PNG (*.png)',
+ 'Plot3D Snapshot JPEG (*.jpg)'))
+
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+
+ if not dialog.exec_():
+ return
+
+ nameFilter = dialog.selectedNameFilter()
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ # Forces the filename extension to match the chosen filter
+ extension = nameFilter.split()[-1][2:-1]
+ if (len(filename) <= len(extension) or
+ filename[-len(extension):].lower() != extension.lower()):
+ filename += extension
+
+ image = plot3d.grabGL()
+ if not image.save(filename):
+ _logger.error('Failed to save image as %s', filename)
+ qt.QMessageBox.critical(
+ self.parent(),
+ 'Save snapshot as',
+ 'Failed to save snapshot')
+
+
+class PrintAction(Plot3DAction):
+ """QAction to provide printing of a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param Plot3DWidget plot3d: Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(PrintAction, self).__init__(parent, plot3d)
+
+ self.setIcon(getQIcon('document-print'))
+ self.setText('Print...')
+ self.setToolTip('Print a snapshot of the 3D scene')
+ self.setCheckable(False)
+ self.setShortcut(qt.QKeySequence.Print)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.triggered[bool].connect(self._triggered)
+
+ def getPrinter(self):
+ """Return the QPrinter instance used for printing.
+
+ :rtype: qt.QPrinter
+ """
+ # TODO This is a hack to sync with silx plot PrintAction
+ # This needs to be centralized
+ if _PrintAction._printer is None:
+ _PrintAction._printer = qt.QPrinter()
+ return _PrintAction._printer
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error('Cannot print widget, no associated Plot3DWidget')
+ else:
+ printer = self.getPrinter()
+ dialog = qt.QPrintDialog(printer, plot3d)
+ dialog.setWindowTitle('Print Plot3D snapshot')
+ if not dialog.exec_():
+ return
+
+ image = plot3d.grabGL()
+
+ # Draw pixmap with painter
+ painter = qt.QPainter()
+ if not painter.begin(printer):
+ return
+
+ if (printer.pageRect().width() < image.width() or
+ printer.pageRect().height() < image.height()):
+ # Downscale to page
+ xScale = printer.pageRect().width() / image.width()
+ yScale = printer.pageRect().height() / image.height()
+ scale = min(xScale, yScale)
+ else:
+ scale = 1.
+
+ rect = qt.QRectF(0,
+ 0,
+ scale * image.width(),
+ scale * image.height())
+ painter.drawImage(rect, image)
+ painter.end()
+
+
+class VideoAction(Plot3DAction):
+ """This action triggers the recording of a video of the scene.
+
+ The scene is rotated 360 degrees around a vertical axis.
+
+ :param parent: Action parent see :class:`QAction`.
+ """
+
+ PNG_SERIE_FILTER = 'Serie of PNG files (*.png)'
+ MNG_FILTER = 'Multiple-image Network Graphics file (*.mng)'
+
+ def __init__(self, parent, plot3d=None):
+ super(VideoAction, self).__init__(parent, plot3d)
+ self.setText('Record video..')
+ self.setIcon(getQIcon('camera'))
+ self.setToolTip(
+ 'Record a video of a 360 degrees rotation of the 3D scene.')
+ self.setCheckable(False)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ """Action triggered callback"""
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.warning(
+ 'Ignoring action triggered without Plot3DWidget set')
+ return
+
+ dialog = qt.QFileDialog(parent=plot3d)
+ dialog.setWindowTitle('Save video as...')
+ dialog.setModal(True)
+ dialog.setNameFilters([self.PNG_SERIE_FILTER,
+ self.MNG_FILTER])
+ dialog.setFileMode(dialog.AnyFile)
+ dialog.setAcceptMode(dialog.AcceptSave)
+
+ if not dialog.exec_():
+ return
+
+ nameFilter = dialog.selectedNameFilter()
+ filename = dialog.selectedFiles()[0]
+
+ # Forces the filename extension to match the chosen filter
+ extension = nameFilter.split()[-1][2:-1]
+ if (len(filename) <= len(extension) or
+ filename[-len(extension):].lower() != extension.lower()):
+ filename += extension
+
+ nbFrames = int(4. * 25) # 4 seconds, 25 fps
+
+ if nameFilter == self.PNG_SERIE_FILTER:
+ self._saveAsPNGSerie(filename, nbFrames)
+ elif nameFilter == self.MNG_FILTER:
+ self._saveAsMNG(filename, nbFrames)
+ else:
+ _logger.error('Unsupported file filter: %s', nameFilter)
+
+ def _saveAsPNGSerie(self, filename, nbFrames):
+ """Save video as serie of PNG files.
+
+ It adds a counter to the provided filename before the extension.
+
+ :param str filename: filename to use as template
+ :param int nbFrames: Number of frames to generate
+ """
+ plot3d = self.getPlot3DWidget()
+ assert plot3d is not None
+
+ # Define filename template
+ nbDigits = int(numpy.log10(nbFrames)) + 1
+ indexFormat = '%%0%dd' % nbDigits
+ extensionIndex = filename.rfind('.')
+ filenameFormat = \
+ filename[:extensionIndex] + indexFormat + filename[extensionIndex:]
+
+ try:
+ for index, image in enumerate(self._video360(nbFrames)):
+ image.save(filenameFormat % index)
+ except GeneratorExit:
+ pass
+
+ def _saveAsMNG(self, filename, nbFrames):
+ """Save video as MNG file.
+
+ :param str filename: filename to use
+ :param int nbFrames: Number of frames to generate
+ """
+ plot3d = self.getPlot3DWidget()
+ assert plot3d is not None
+
+ frames = (convertQImageToArray(im) for im in self._video360(nbFrames))
+ try:
+ with open(filename, 'wb') as file_:
+ for chunk in mng.convert(frames, nb_images=nbFrames):
+ file_.write(chunk)
+ except GeneratorExit:
+ os.remove(filename) # Saving aborted, delete file
+
+ def _video360(self, nbFrames):
+ """Run the video and provides the images
+
+ :param int nbFrames: The number of frames to generate for
+ :return: Iterator of QImage of the video sequence
+ """
+ plot3d = self.getPlot3DWidget()
+ assert plot3d is not None
+
+ angleStep = 360. / nbFrames
+
+ # Create progress bar dialog
+ dialog = qt.QDialog(plot3d)
+ dialog.setWindowTitle('Record Video')
+ layout = qt.QVBoxLayout(dialog)
+ progress = qt.QProgressBar()
+ progress.setRange(0, nbFrames)
+ layout.addWidget(progress)
+
+ btnBox = qt.QDialogButtonBox(qt.QDialogButtonBox.Abort)
+ btnBox.rejected.connect(dialog.reject)
+ layout.addWidget(btnBox)
+
+ dialog.setModal(True)
+ dialog.show()
+
+ qapp = qt.QApplication.instance()
+
+ for frame in range(nbFrames):
+ progress.setValue(frame)
+ image = plot3d.grabGL()
+ yield image
+ plot3d.viewport.orbitCamera('left', angleStep)
+ qapp.processEvents()
+ if not dialog.isVisible():
+ break # It as been rejected by the abort button
+ else:
+ dialog.accept()
+
+ if dialog.result() == qt.QDialog.Rejected:
+ raise GeneratorExit('Aborted')
diff --git a/silx/gui/plot3d/Plot3DToolBar.py b/silx/gui/plot3d/Plot3DToolBar.py
new file mode 100644
index 0000000..cf11362
--- /dev/null
+++ b/silx/gui/plot3d/Plot3DToolBar.py
@@ -0,0 +1,119 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a toolbar with tools for a Plot3DWidget.
+
+It provides:
+
+- Copy
+- Save
+- Print
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "10/01/2017"
+
+import logging
+
+from silx.gui import qt
+
+from . import Plot3DActions
+
+_logger = logging.getLogger(__name__)
+
+
+class Plot3DToolBar(qt.QToolBar):
+ """Toolbar providing icons to copy, save and print the OpenGL scene
+
+ :param parent: See :class:`QWidget`
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, title='Plot3D'):
+ super(Plot3DToolBar, self).__init__(title, parent)
+
+ self._plot3d = None
+
+ self._copyAction = Plot3DActions.CopyAction(parent=self)
+ self.addAction(self._copyAction)
+
+ self._saveAction = Plot3DActions.SaveAction(parent=self)
+ self.addAction(self._saveAction)
+
+ self._videoAction = Plot3DActions.VideoAction(parent=self)
+ self.addAction(self._videoAction)
+
+ self._printAction = Plot3DActions.PrintAction(parent=self)
+ self.addAction(self._printAction)
+
+ def setPlot3DWidget(self, widget):
+ """Set the Plot3DWidget this toolbar is associated with
+
+ :param Plot3DWidget widget: The widget to copy/save/print
+ """
+ self._plot3d = widget
+ self.getCopyAction().setPlot3DWidget(widget)
+ self.getSaveAction().setPlot3DWidget(widget)
+ self.getVideoRecordAction().setPlot3DWidget(widget)
+ self.getPrintAction().setPlot3DWidget(widget)
+
+ def getPlot3DWidget(self):
+ """Return the Plot3DWidget associated to this toolbar.
+
+ If no widget is associated, it returns None.
+
+ :rtype: qt.QWidget
+ """
+ return self._plot3d
+
+ def getCopyAction(self):
+ """Returns the QAction performing copy to clipboard of the Plot3DWidget
+
+ :rtype: qt.QAction
+ """
+ return self._copyAction
+
+ def getSaveAction(self):
+ """Returns the QAction performing save to file of the Plot3DWidget
+
+ :rtype: qt.QAction
+ """
+ return self._saveAction
+
+ def getVideoRecordAction(self):
+ """Returns the QAction performing record video of the Plot3DWidget
+
+ :rtype: qt.QAction
+ """
+ return self._videoAction
+
+ def getPrintAction(self):
+ """Returns the QAction performing printing of the Plot3DWidget
+
+ :rtype: qt.QAction
+ """
+ return self._printAction
diff --git a/silx/gui/plot3d/Plot3DWidget.py b/silx/gui/plot3d/Plot3DWidget.py
new file mode 100644
index 0000000..9c9da0c
--- /dev/null
+++ b/silx/gui/plot3d/Plot3DWidget.py
@@ -0,0 +1,341 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a Qt widget embedding an OpenGL scene."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/01/2017"
+
+
+import logging
+
+from silx.gui import qt
+from silx.gui.plot.Colors import rgba
+from silx.gui.plot3d import Plot3DActions
+from .._utils import convertArrayToQImage
+
+from .._glutils import gl
+from .scene import interaction, primitives, transform
+from . import scene
+
+import numpy
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _OverviewViewport(scene.Viewport):
+ """A scene displaying the orientation of the data in another scene.
+
+ :param Camera camera: The camera to track.
+ """
+
+ def __init__(self, camera=None):
+ super(_OverviewViewport, self).__init__()
+ self.size = 100, 100
+
+ self.scene.transforms = [transform.Scale(2.5, 2.5, 2.5)]
+
+ axes = primitives.Axes()
+ self.scene.children.append(axes)
+
+ if camera is not None:
+ camera.addListener(self._cameraChanged)
+
+ def _cameraChanged(self, source):
+ """Listen to camera in other scene for transformation updates.
+
+ Sync the overview camera to point in the same direction
+ but from a sphere centered on origin.
+ """
+ position = -12. * source.extrinsic.direction
+ self.camera.extrinsic.position = position
+
+ self.camera.extrinsic.setOrientation(
+ source.extrinsic.direction, source.extrinsic.up)
+
+
+class Plot3DWidget(qt.QGLWidget):
+ """QGLWidget with a 3D viewport and an overview."""
+
+ def __init__(self, parent=None):
+ if not qt.QGLFormat.hasOpenGL(): # Check if any OpenGL is available
+ raise RuntimeError(
+ 'OpenGL is not available on this platform: 3D disabled')
+
+ self._devicePixelRatio = 1.0 # Store GL canvas/QWidget ratio
+ self._isOpenGL21 = False
+ self._firstRender = True
+
+ format_ = qt.QGLFormat()
+ format_.setRgba(True)
+ format_.setDepth(False)
+ format_.setStencil(False)
+ format_.setVersion(2, 1)
+ format_.setDoubleBuffer(True)
+
+ super(Plot3DWidget, self).__init__(format_, parent)
+ self.setAutoFillBackground(False)
+ self.setMouseTracking(True)
+
+ self.setFocusPolicy(qt.Qt.StrongFocus)
+ self._copyAction = Plot3DActions.CopyAction(parent=self, plot3d=self)
+ self.addAction(self._copyAction)
+
+ self._updating = False # True if an update is requested
+
+ # Main viewport
+ self.viewport = scene.Viewport()
+ self.viewport.background = 0.2, 0.2, 0.2, 1.
+
+ sceneScale = transform.Scale(1., 1., 1.)
+ self.viewport.scene.transforms = [sceneScale,
+ transform.Translate(0., 0., 0.)]
+
+ # Overview area
+ self.overview = _OverviewViewport(self.viewport.camera)
+
+ self.setBackgroundColor((0.2, 0.2, 0.2, 1.))
+
+ # Window describing on screen area to render
+ self.window = scene.Window(mode='framebuffer')
+ self.window.viewports = [self.viewport, self.overview]
+
+ self.eventHandler = interaction.CameraControl(
+ self.viewport, orbitAroundCenter=False,
+ mode='position', scaleTransform=sceneScale,
+ selectCB=None)
+
+ self.viewport.addListener(self._redraw)
+
+ def setProjection(self, projection):
+ """Change the projection in use.
+
+ :param str projection: In 'perspective', 'orthographic'.
+ """
+ if projection == 'orthographic':
+ projection = transform.Orthographic(size=self.viewport.size)
+ elif projection == 'perspective':
+ projection = transform.Perspective(fovy=30.,
+ size=self.viewport.size)
+ else:
+ raise RuntimeError('Unsupported projection: %s' % projection)
+
+ self.viewport.camera.intrinsic = projection
+ self.viewport.resetCamera()
+
+ def getProjection(self):
+ """Return the current camera projection mode as a str.
+
+ See :meth:`setProjection`
+ """
+ projection = self.viewport.camera.intrinsic
+ if isinstance(projection, transform.Orthographic):
+ return 'orthographic'
+ elif isinstance(projection, transform.Perspective):
+ return 'perspective'
+ else:
+ raise RuntimeError('Unknown projection in use')
+
+ def setBackgroundColor(self, color):
+ """Set the background color of the OpenGL view.
+
+ :param color: RGB color of the isosurface: name, #RRGGBB or RGB values
+ :type color:
+ QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8
+ """
+ color = rgba(color)
+ self.viewport.background = color
+ self.overview.background = color[0]*0.5, color[1]*0.5, color[2]*0.5, 1.
+
+ def getBackgroundColor(self):
+ """Returns the RGBA background color (QColor)."""
+ return qt.QColor.fromRgbF(*self.viewport.background)
+
+ def centerScene(self):
+ """Position the center of the scene at the center of rotation."""
+ self.viewport.resetCamera()
+
+ def resetZoom(self, face='front'):
+ """Reset the camera position to a default.
+
+ :param str face: The direction the camera is looking at:
+ side, front, back, top, bottom, right, left.
+ Default: front.
+ """
+ self.viewport.camera.extrinsic.reset(face=face)
+ self.centerScene()
+
+ def _redraw(self, source=None):
+ """Viewport listener to require repaint"""
+ if not self._updating and self.viewport.dirty:
+ self._updating = True # Mark that an update is requested
+ self.update() # Queued repaint (i.e., asynchronous)
+
+ def sizeHint(self):
+ return qt.QSize(400, 300)
+
+ def initializeGL(self):
+ # Check if OpenGL2 is available
+ versionflags = self.format().openGLVersionFlags()
+ self._isOpenGL21 = bool(versionflags & qt.QGLFormat.OpenGL_Version_2_1)
+ if not self._isOpenGL21:
+ _logger.error(
+ '3D rendering is disabled: OpenGL 2.1 not available')
+
+ messageBox = qt.QMessageBox(parent=self)
+ messageBox.setIcon(qt.QMessageBox.Critical)
+ messageBox.setWindowTitle('Error')
+ messageBox.setText('3D rendering is disabled.\n\n'
+ 'Reason: OpenGL 2.1 is not available.')
+ messageBox.addButton(qt.QMessageBox.Ok)
+ messageBox.setWindowModality(qt.Qt.WindowModal)
+ messageBox.setAttribute(qt.Qt.WA_DeleteOnClose)
+ messageBox.show()
+
+ def paintGL(self):
+ # In case paintGL is called by the system and not through _redraw,
+ # Mark as updating.
+ self._updating = True
+
+ if hasattr(self, 'windowHandle'): # Qt 5
+ devicePixelRatio = self.windowHandle().devicePixelRatio()
+ if devicePixelRatio != self._devicePixelRatio:
+ # Move window from one screen to another one
+ self._devicePixelRatio = devicePixelRatio
+ # Resize might not be called, so call it explicitly
+ self.resizeGL(int(self.width() * devicePixelRatio),
+ int(self.height() * devicePixelRatio))
+
+ if not self._isOpenGL21:
+ # Cannot render scene, just clear the color buffer.
+ ox, oy = self.viewport.origin
+ w, h = self.viewport.size
+ gl.glViewport(ox, oy, w, h)
+
+ gl.glClearColor(*self.viewport.background)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+
+ else:
+ # Update near and far planes only if viewport needs refresh
+ if self.viewport.dirty:
+ self.viewport.adjustCameraDepthExtent()
+
+ self.window.render(self.context(), self._devicePixelRatio)
+
+ if self._firstRender: # TODO remove this ugly hack
+ self._firstRender = False
+ self.centerScene()
+ self._updating = False
+
+ def resizeGL(self, width, height):
+ self.window.size = width, height
+ self.viewport.size = width, height
+ overviewWidth, overviewHeight = self.overview.size
+ self.overview.origin = width - overviewWidth, height - overviewHeight
+
+ def grabGL(self):
+ """Renders the OpenGL scene into a numpy array
+
+ :returns: OpenGL scene RGB rasterization
+ :rtype: QImage
+ """
+ if not self._isOpenGL21:
+ _logger.error('OpenGL 2.1 not available, cannot save OpenGL image')
+ height, width = self.window.shape
+ image = numpy.zeros((height, width, 3), dtype=numpy.uint8)
+
+ else:
+ self.makeCurrent()
+ image = self.window.grab(qt.QGLContext.currentContext())
+
+ return convertArrayToQImage(image)
+
+ def wheelEvent(self, event):
+ xpixel = event.x() * self._devicePixelRatio
+ ypixel = event.y() * self._devicePixelRatio
+ if hasattr(event, 'delta'): # Qt4
+ angle = event.delta() / 8.
+ else: # Qt5
+ angle = event.angleDelta().y() / 8.
+ event.accept()
+
+ if angle != 0:
+ self.makeCurrent()
+ self.eventHandler.handleEvent('wheel', xpixel, ypixel, angle)
+
+ def keyPressEvent(self, event):
+ keycode = event.key()
+ # No need to accept QKeyEvent
+
+ converter = {
+ qt.Qt.Key_Left: 'left',
+ qt.Qt.Key_Right: 'right',
+ qt.Qt.Key_Up: 'up',
+ qt.Qt.Key_Down: 'down'
+ }
+ direction = converter.get(keycode, None)
+ if direction is not None:
+ if event.modifiers() == qt.Qt.ControlModifier:
+ self.viewport.camera.rotate(direction)
+ elif event.modifiers() == qt.Qt.ShiftModifier:
+ self.viewport.moveCamera(direction)
+ else:
+ self.viewport.orbitCamera(direction)
+
+ else:
+ # Key not handled, call base class implementation
+ super(Plot3DWidget, self).keyPressEvent(event)
+
+ # Mouse events #
+ _MOUSE_BTNS = {1: 'left', 2: 'right', 4: 'middle'}
+
+ def mousePressEvent(self, event):
+ xpixel = event.x() * self._devicePixelRatio
+ ypixel = event.y() * self._devicePixelRatio
+ btn = self._MOUSE_BTNS[event.button()]
+ event.accept()
+
+ self.makeCurrent()
+ self.eventHandler.handleEvent('press', xpixel, ypixel, btn)
+
+ def mouseMoveEvent(self, event):
+ xpixel = event.x() * self._devicePixelRatio
+ ypixel = event.y() * self._devicePixelRatio
+ event.accept()
+
+ self.makeCurrent()
+ self.eventHandler.handleEvent('move', xpixel, ypixel)
+
+ def mouseReleaseEvent(self, event):
+ xpixel = event.x() * self._devicePixelRatio
+ ypixel = event.y() * self._devicePixelRatio
+ btn = self._MOUSE_BTNS[event.button()]
+ event.accept()
+
+ self.makeCurrent()
+ self.eventHandler.handleEvent('release', xpixel, ypixel, btn)
diff --git a/silx/gui/plot3d/Plot3DWindow.py b/silx/gui/plot3d/Plot3DWindow.py
new file mode 100644
index 0000000..4658d38
--- /dev/null
+++ b/silx/gui/plot3d/Plot3DWindow.py
@@ -0,0 +1,94 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a QMainWindow with a 3D scene and associated toolbar.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/01/2017"
+
+
+from silx.gui import qt
+
+from .Plot3DToolBar import Plot3DToolBar
+from .Plot3DWidget import Plot3DWidget
+from .ViewpointToolBar import ViewpointToolBar
+
+
+class Plot3DWindow(qt.QMainWindow):
+ """QGLWidget with a 3D viewport and an overview."""
+
+ def __init__(self, parent=None):
+ super(Plot3DWindow, self).__init__(parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+
+ self._plot3D = Plot3DWidget()
+ self.setCentralWidget(self._plot3D)
+ self.addToolBar(
+ ViewpointToolBar(parent=self, plot3D=self._plot3D))
+ toolbar = Plot3DToolBar(parent=self)
+ toolbar.setPlot3DWidget(self._plot3D)
+ self.addToolBar(toolbar)
+ self.addActions(toolbar.actions())
+
+ def getPlot3DWidget(self):
+ """Get the :class:`Plot3DWidget` of this window"""
+ return self._plot3D
+
+ # Proxy to Plot3DWidget
+
+ def setProjection(self, projection):
+ return self._plot3D.setProjection(projection)
+
+ setProjection.__doc__ = Plot3DWidget.setProjection.__doc__
+
+ def getProjection(self):
+ return self._plot3D.getProjection()
+
+ getProjection.__doc__ = Plot3DWidget.getProjection.__doc__
+
+ def centerScene(self):
+ return self._plot3D.centerScene()
+
+ centerScene.__doc__ = Plot3DWidget.centerScene.__doc__
+
+ def resetZoom(self):
+ return self._plot3D.resetZoom()
+
+ resetZoom.__doc__ = Plot3DWidget.resetZoom.__doc__
+
+ def getBackgroundColor(self):
+ return self._plot3D.getBackgroundColor()
+
+ getBackgroundColor.__doc__ = Plot3DWidget.getBackgroundColor.__doc__
+
+ def setBackgroundColor(self, color):
+ return self._plot3D.setBackgroundColor(color)
+
+ setBackgroundColor.__doc__ = Plot3DWidget.setBackgroundColor.__doc__
diff --git a/silx/gui/plot3d/SFViewParamTree.py b/silx/gui/plot3d/SFViewParamTree.py
new file mode 100644
index 0000000..38d4e37
--- /dev/null
+++ b/silx/gui/plot3d/SFViewParamTree.py
@@ -0,0 +1,1467 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module provides a tree widget to set/view parameters of a ScalarFieldView.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["D. N."]
+__license__ = "MIT"
+__date__ = "10/01/2017"
+
+import logging
+import sys
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.icons import getQIcon
+
+from .ScalarFieldView import Isosurface
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ModelColumns(object):
+ NameColumn, ValueColumn, ColumnMax = range(3)
+ ColumnNames = ['Name', 'Value']
+
+
+class SubjectItem(qt.QStandardItem):
+ """
+ Base class for observers items.
+
+ Subclassing:
+ ------------
+ The following method can/should be reimplemented:
+ - _init
+ - _pullData
+ - _pushData
+ - _setModelData
+ - _subjectChanged
+ - getEditor
+ - getSignals
+ - leftClicked
+ - queryRemove
+ - setEditorData
+
+ Also the following attributes are available:
+ - editable
+ - persistent
+
+ :param subject: object that this item will be observing.
+ """
+
+ editable = False
+ """ boolean: set to True to make the item editable. """
+
+ persistent = False
+ """
+ boolean: set to True to make the editor persistent.
+ See : Qt.QAbstractItemView.openPersistentEditor
+ """
+
+ def __init__(self, subject, *args):
+
+ super(SubjectItem, self).__init__(*args)
+
+ self.setEditable(self.editable)
+
+ self.__subject = None
+ self.subject = subject
+
+ def setData(self, value, role=qt.Qt.UserRole, pushData=True):
+ """
+ Overloaded method from QStandardItem. The pushData keyword tells
+ the item to push data to the subject if the role is equal to EditRole.
+ This is useful to let this method know if the setData method was called
+ internaly or from the view.
+
+ :param value: the value ti set to data
+ :param role: role in the item
+ :param pushData: if True push value in the existing data.
+ """
+ if role == qt.Qt.EditRole and pushData:
+ setValue = self._pushData(value, role)
+ if setValue != value:
+ value = setValue
+ super(SubjectItem, self).setData(value, role)
+
+ subject = property(lambda self: self.__subject)
+
+ @subject.setter
+ def subject(self, subject):
+ if self.__subject is not None:
+ raise ValueError('Subject already set '
+ ' (subject change not supported).')
+ self.__subject = subject
+ if subject is not None:
+ self._init()
+ self._connectSignals()
+
+ def _connectSignals(self):
+ """
+ Connects the signals. Called when the subject is set.
+ """
+
+ def gen_slot(_sigIdx):
+ def slotfn(*args, **kwargs):
+ self._subjectChanged(signalIdx=_sigIdx,
+ args=args,
+ kwargs=kwargs)
+ return slotfn
+
+ if self.__subject is not None:
+ self.__slots = slots = []
+
+ signals = self.getSignals()
+
+ if signals:
+ if not isinstance(signals, (list, tuple)):
+ signals = [signals]
+ for sigIdx, signal in enumerate(signals):
+ slot = gen_slot(sigIdx)
+ signal.connect(slot)
+ slots.append((signal, slot))
+
+ def _disconnectSignals(self):
+ """
+ Disconnects all subject's signal
+ """
+ if self.__slots:
+ for signal, slot in self.__slots:
+ try:
+ signal.disconnect(slot)
+ except TypeError:
+ pass
+
+ def _enableRow(self, enable):
+ """
+ Set the enabled state for this cell, or for the whole row
+ if this item has a parent.
+
+ :param bool enable: True if we wan't to enable the cell
+ """
+ parent = self.parent()
+ model = self.model()
+ if model is None or parent is None:
+ # no parent -> no siblings
+ self.setEnabled(enable)
+ return
+
+ for col in range(model.columnCount()):
+ sibling = parent.child(self.row(), col)
+ sibling.setEnabled(enable)
+
+ #################################################################
+ # Overloadable methods
+ #################################################################
+
+ def getSignals(self):
+ """
+ Returns the list of this items subject's signals that
+ this item will be listening to.
+
+ :return: list.
+ """
+ return None
+
+ def _subjectChanged(self, signalIdx=None, args=None, kwargs=None):
+ """
+ Called when one of the signals is triggered. Default implementation
+ just calls _pullData, compares the result to the current value stored
+ as Qt.EditRole, and stores the new value if it is different. It also
+ stores its str representation as Qt.DisplayRole
+
+ :param signalIdx: index of the triggered signal. The value passed
+ is the same as the signal position in the list returned by
+ SubjectItem.getSignals.
+ :param args: arguments received from the signal
+ :param kwargs: keyword arguments received from the signal
+ """
+ data = self._pullData()
+ if data == self.data(qt.Qt.EditRole):
+ return
+ self.setData(data, role=qt.Qt.DisplayRole, pushData=False)
+ self.setData(data, role=qt.Qt.EditRole, pushData=False)
+
+ def _pullData(self):
+ """
+ Pulls data from the subject.
+
+ :return: subject data
+ """
+ return None
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ """
+ Pushes data to the subject and returns the actual value that was stored
+
+ :return: the value that was stored
+ """
+ return value
+
+ def _init(self):
+ """
+ Called when the subject is set.
+ :return:
+ """
+ self._subjectChanged()
+
+ def getEditor(self, parent, option, index):
+ """
+ Returns the editor widget used to edit this item's data. The arguments
+ are the one passed to the QStyledItemDelegate.createEditor method.
+
+ :param parent: the Qt parent of the editor
+ :param option:
+ :param index:
+ :return:
+ """
+ return None
+
+ def setEditorData(self, editor):
+ """
+ This is called by the View's delegate just before the editor is shown,
+ its purpose it to setup the editors contents. Return False to use
+ the delegate's default behaviour.
+
+ :param editor:
+ :return:
+ """
+ return True
+
+ def _setModelData(self, editor):
+ """
+ This is called by the View's delegate just before the editor is closed,
+ its allows this item to update itself with data from the editor.
+
+ :param editor:
+ :return:
+ """
+ return False
+
+ def queryRemove(self, view=None):
+ """
+ This is called by the view to ask this items if it (the view) can
+ remove it. Return True to let the view know that the item can be
+ removed.
+
+ :param view:
+ :return:
+ """
+ return False
+
+ def leftClicked(self):
+ """
+ This method is called by the view when the item's cell if left clicked.
+
+ :return:
+ """
+ pass
+
+
+# View settings ###############################################################
+
+class ColorItem(SubjectItem):
+ """color item."""
+ editable = True
+ persistent = True
+
+ def getEditor(self, parent, option, index):
+ editor = QColorEditor(parent)
+ editor.color = self.getColor()
+ editor.sigColorChanged.connect(self._editorSlot)
+ return editor
+
+ def _editorSlot(self, color):
+ self.setData(color, qt.Qt.EditRole)
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self.setColor(value)
+ return self.getColor()
+
+ def _pullData(self):
+ self.getColor()
+
+ def setColor(self, color):
+ """Override to implement actual color setter"""
+ pass
+
+
+class BackgroundColorItem(ColorItem):
+ itemName = 'Background'
+
+ def setColor(self, color):
+ self.subject.setBackgroundColor(color)
+
+ def getColor(self):
+ return self.subject.getBackgroundColor()
+
+
+class ForegroundColorItem(ColorItem):
+ itemName = 'Foreground'
+
+ def setColor(self, color):
+ self.subject.setForegroundColor(color)
+
+ def getColor(self):
+ return self.subject.getForegroundColor()
+
+
+class HighlightColorItem(ColorItem):
+ itemName = 'Highlight'
+
+ def setColor(self, color):
+ self.subject.setHighlightColor(color)
+
+ def getColor(self):
+ return self.subject.getHighlightColor()
+
+
+class ViewSettingsItem(qt.QStandardItem):
+ """Viewport settings"""
+
+ def __init__(self, subject, *args):
+
+ super(ViewSettingsItem, self).__init__(*args)
+
+ self.setEditable(False)
+
+ classes = BackgroundColorItem, ForegroundColorItem, HighlightColorItem
+ for cls in classes:
+ titleItem = qt.QStandardItem(cls.itemName)
+ titleItem.setEditable(False)
+ self.appendRow([titleItem, cls(subject)])
+
+
+# Data information ############################################################
+
+class DataChangedItem(SubjectItem):
+ """
+ Base class for items listening to ScalarFieldView.sigDataChanged
+ """
+
+ def getSignals(self):
+ subject = self.subject
+ if subject:
+ return subject.sigDataChanged
+ return None
+
+ def _init(self):
+ self._subjectChanged()
+
+
+class DataTypeItem(DataChangedItem):
+ itemName = 'dtype'
+
+ def _pullData(self):
+ data = self.subject.getData(copy=False)
+ return ((data is not None) and str(data.dtype)) or 'N/A'
+
+
+class DataShapeItem(DataChangedItem):
+ itemName = 'size'
+
+ def _pullData(self):
+ data = self.subject.getData(copy=False)
+ if data is None:
+ return 'N/A'
+ else:
+ return str(list(reversed(data.shape)))
+
+
+class OffsetItem(DataChangedItem):
+ itemName = 'offset'
+
+ def _pullData(self):
+ offset = self.subject.getTranslation()
+ return ((offset is not None) and str(offset)) or 'N/A'
+
+
+class ScaleItem(DataChangedItem):
+ itemName = 'scale'
+
+ def _pullData(self):
+ scale = self.subject.getScale()
+ return ((scale is not None) and str(scale)) or 'N/A'
+
+
+class DataSetItem(qt.QStandardItem):
+
+ def __init__(self, subject, *args):
+
+ super(DataSetItem, self).__init__(*args)
+
+ self.setEditable(False)
+
+ klasses = [DataTypeItem, DataShapeItem, OffsetItem, ScaleItem]
+ for klass in klasses:
+ titleItem = qt.QStandardItem(klass.itemName)
+ titleItem.setEditable(False)
+ self.appendRow([titleItem, klass(subject)])
+
+
+# Isosurface ##################################################################
+
+class IsoSurfaceRootItem(SubjectItem):
+ """
+ Root (i.e : column index 0) Isosurface item.
+ """
+
+ def getSignals(self):
+ subject = self.subject
+ return [subject.sigColorChanged,
+ subject.sigVisibilityChanged]
+
+ def _subjectChanged(self, signalIdx=None, args=None, kwargs=None):
+ if signalIdx == 0:
+ color = self.subject.getColor()
+ self.setData(color, qt.Qt.DecorationRole)
+ elif signalIdx == 1:
+ visible = args[0]
+ self.setCheckState((visible and qt.Qt.Checked) or qt.Qt.Unchecked)
+
+ def _init(self):
+ self.setCheckable(True)
+
+ isosurface = self.subject
+ color = isosurface.getColor()
+ visible = isosurface.isVisible()
+ self.setData(color, qt.Qt.DecorationRole)
+ self.setCheckState((visible and qt.Qt.Checked) or qt.Qt.Unchecked)
+
+ nameItem = qt.QStandardItem('Level')
+ sliderItem = IsoSurfaceLevelSlider(self.subject)
+ self.appendRow([nameItem, sliderItem])
+
+ nameItem = qt.QStandardItem('Color')
+ nameItem.setEditable(False)
+ valueItem = IsoSurfaceColorItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Opacity')
+ nameItem.setTextAlignment(qt.Qt.AlignLeft | qt.Qt.AlignTop)
+ nameItem.setEditable(False)
+ valueItem = IsoSurfaceAlphaItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem()
+ nameItem.setEditable(False)
+ valueItem = IsoSurfaceAlphaLegendItem(self.subject)
+ valueItem.setEditable(False)
+ self.appendRow([nameItem, valueItem])
+
+ def queryRemove(self, view=None):
+ buttons = qt.QMessageBox.Ok | qt.QMessageBox.Cancel
+ ans = qt.QMessageBox.question(view,
+ 'Remove isosurface',
+ 'Remove the selected iso-surface?',
+ buttons=buttons)
+ if ans == qt.QMessageBox.Ok:
+ sfview = self.subject.parent()
+ if sfview:
+ sfview.removeIsosurface(self.subject)
+ return False
+ return False
+
+ def leftClicked(self):
+ checked = (self.checkState() == qt.Qt.Checked)
+ visible = self.subject.isVisible()
+ if checked != visible:
+ self.subject.setVisible(checked)
+
+
+class IsoSurfaceLevelItem(SubjectItem):
+ """
+ Base class for the isosurface level items.
+ """
+ editable = True
+
+ def getSignals(self):
+ subject = self.subject
+ return [subject.sigLevelChanged,
+ subject.sigVisibilityChanged]
+
+ def setEditorData(self, editor):
+ return False
+
+ def _pullData(self):
+ return self.subject.getLevel()
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self.subject.setLevel(value)
+ return self.subject.getLevel()
+
+
+class _IsoLevelSlider(qt.QSlider):
+ """QSlider used for iso-surface level"""
+
+ def __init__(self, parent, subject):
+ super(_IsoLevelSlider, self).__init__(parent=parent)
+ self.subject = subject
+
+ self.sliderReleased.connect(self.__sliderReleased)
+
+ self.subject.sigLevelChanged.connect(self.setLevel)
+ self.subject.parent().sigDataChanged.connect(self.__dataChanged)
+
+ def setLevel(self, level):
+ """Set slider from iso-surface level"""
+ dataRange = self.subject.parent().getDataRange()
+
+ if dataRange is not None and None not in dataRange:
+ width = dataRange[1] - dataRange[0]
+ if width > 0:
+ sliderWidth = self.maximum() - self.minimum()
+ sliderPosition = sliderWidth * (level - dataRange[0]) / width
+ self.setValue(sliderPosition)
+
+ def __dataChanged(self):
+ """Handles data update to refresh slider range if needed"""
+ self.setLevel(self.subject.getLevel())
+
+ def __sliderReleased(self):
+ value = self.value()
+ dataRange = self.subject.parent().getDataRange()
+ width = dataRange[1] - dataRange[0]
+ sliderWidth = self.maximum() - self.minimum()
+ level = dataRange[0] + width * value / sliderWidth
+ self.subject.setLevel(level)
+
+
+class IsoSurfaceLevelSlider(IsoSurfaceLevelItem):
+ """
+ Isosurface level item with a slider editor.
+ """
+ nTicks = 1000
+ persistent = True
+
+ def getEditor(self, parent, option, index):
+ editor = _IsoLevelSlider(parent, self.subject)
+ editor.setOrientation(qt.Qt.Horizontal)
+ editor.setMinimum(0)
+ editor.setMaximum(self.nTicks)
+
+ editor.setSingleStep(1)
+
+ editor.setLevel(self.subject.getLevel())
+ return editor
+
+ def setEditorData(self, editor):
+ return True
+
+ def _setModelData(self, editor):
+ return True
+
+
+class IsoSurfaceColorItem(SubjectItem):
+ """
+ Isosurface color item.
+ """
+ editable = True
+ persistent = True
+
+ def getSignals(self):
+ return self.subject.sigColorChanged
+
+ def getEditor(self, parent, option, index):
+ editor = QColorEditor(parent)
+ color = self.subject.getColor()
+ color.setAlpha(255)
+ editor.color = color
+ editor.sigColorChanged.connect(self.__editorChanged)
+ return editor
+
+ def __editorChanged(self, color):
+ color.setAlpha(self.subject.getColor().alpha())
+ self.subject.setColor(color)
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self.subject.setColor(value)
+ return self.subject.getColor()
+
+
+class QColorEditor(qt.QWidget):
+ """
+ QColor editor.
+ """
+ sigColorChanged = qt.Signal(object)
+
+ color = property(lambda self: qt.QColor(self.__color))
+
+ @color.setter
+ def color(self, color):
+ self._setColor(color)
+ self.__previousColor = color
+
+ def __init__(self, *args, **kwargs):
+ super(QColorEditor, self).__init__(*args, **kwargs)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ button = qt.QToolButton()
+ icon = qt.QIcon(qt.QPixmap(32, 32))
+ button.setIcon(icon)
+ layout.addWidget(button)
+ button.clicked.connect(self.__showColorDialog)
+ layout.addStretch(1)
+
+ self.__color = None
+ self.__previousColor = None
+
+ def sizeHint(self):
+ return qt.QSize(0, 0)
+
+ def _setColor(self, qColor):
+ button = self.findChild(qt.QToolButton)
+ pixmap = qt.QPixmap(32, 32)
+ pixmap.fill(qColor)
+ button.setIcon(qt.QIcon(pixmap))
+ self.__color = qColor
+
+ def __showColorDialog(self):
+ dialog = qt.QColorDialog(parent=self)
+ if sys.platform == 'darwin':
+ # Use of native color dialog on macos might cause problems
+ dialog.setOption(qt.QColorDialog.DontUseNativeDialog, True)
+
+ self.__previousColor = self.__color
+ dialog.setAttribute(qt.Qt.WA_DeleteOnClose)
+ dialog.setModal(True)
+ dialog.currentColorChanged.connect(self.__colorChanged)
+ dialog.finished.connect(self.__dialogClosed)
+ dialog.show()
+
+ def __colorChanged(self, color):
+ self.__color = color
+ self._setColor(color)
+ self.sigColorChanged.emit(color)
+
+ def __dialogClosed(self, result):
+ if result == qt.QDialog.Rejected:
+ self.__colorChanged(self.__previousColor)
+ self.__previousColor = None
+
+
+class IsoSurfaceAlphaItem(SubjectItem):
+ """
+ Isosurface alpha item.
+ """
+ editable = True
+ persistent = True
+
+ def _init(self):
+ pass
+
+ def getSignals(self):
+ return self.subject.sigColorChanged
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QSlider(parent)
+ editor.setOrientation(qt.Qt.Horizontal)
+ editor.setMinimum(0)
+ editor.setMaximum(255)
+
+ color = self.subject.getColor()
+ editor.setValue(color.alpha())
+
+ editor.valueChanged.connect(self.__editorChanged)
+
+ return editor
+
+ def __editorChanged(self, value):
+ color = self.subject.getColor()
+ color.setAlpha(value)
+ self.subject.setColor(color)
+
+ def setEditorData(self, editor):
+ return True
+
+ def _setModelData(self, editor):
+ return True
+
+
+class IsoSurfaceAlphaLegendItem(SubjectItem):
+ """Legend to place under opacity slider"""
+
+ editable = False
+ persistent = True
+
+ def getEditor(self, parent, option, index):
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+ layout.addWidget(qt.QLabel('0'))
+ layout.addStretch(1)
+ layout.addWidget(qt.QLabel('1'))
+
+ editor = qt.QWidget(parent)
+ editor.setLayout(layout)
+ return editor
+
+
+class IsoSurfaceCount(SubjectItem):
+ """
+ Item displaying the number of isosurfaces.
+ """
+
+ def getSignals(self):
+ subject = self.subject
+ return [subject.sigIsosurfaceAdded, subject.sigIsosurfaceRemoved]
+
+ def _pullData(self):
+ return len(self.subject.getIsosurfaces())
+
+
+class IsoSurfaceAddRemoveWidget(qt.QWidget):
+
+ sigViewTask = qt.Signal(str)
+ """Signal for the tree view to perform some task"""
+
+ def __init__(self, parent, item):
+ super(IsoSurfaceAddRemoveWidget, self).__init__(parent)
+ self._item = item
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ addBtn = qt.QToolButton()
+ addBtn.setText('+')
+ addBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
+ layout.addWidget(addBtn)
+ addBtn.clicked.connect(self.__addClicked)
+
+ removeBtn = qt.QToolButton()
+ removeBtn.setText('-')
+ removeBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
+ layout.addWidget(removeBtn)
+ removeBtn.clicked.connect(self.__removeClicked)
+
+ layout.addStretch(1)
+
+ def __addClicked(self):
+ sfview = self._item.subject
+ if not sfview:
+ return
+ dataRange = sfview.getDataRange()
+ if dataRange is None:
+ dataRange = [0, 1]
+
+ sfview.addIsosurface(numpy.mean(dataRange), '#0000FF')
+
+ def __removeClicked(self):
+ self.sigViewTask.emit('remove_iso')
+
+
+class IsoSurfaceAddRemoveItem(SubjectItem):
+ """
+ Item displaying a simple QToolButton allowing to add an isosurface.
+ """
+ persistent = True
+
+ def getEditor(self, parent, option, index):
+ return IsoSurfaceAddRemoveWidget(parent, self)
+
+
+class IsoSurfaceGroup(SubjectItem):
+ """
+ Root item for the list of isosurface items.
+ """
+ def getSignals(self):
+ subject = self.subject
+ return [subject.sigIsosurfaceAdded, subject.sigIsosurfaceRemoved]
+
+ def _subjectChanged(self, signalIdx=None, args=None, kwargs=None):
+ if signalIdx == 0:
+ if len(args) >= 1:
+ isosurface = args[0]
+ if not isinstance(isosurface, Isosurface):
+ raise ValueError('Expected an isosurface instance.')
+ self.__addIsosurface(isosurface)
+ else:
+ raise ValueError('Expected an isosurface instance.')
+ elif signalIdx == 1:
+ if len(args) >= 1:
+ isosurface = args[0]
+ if not isinstance(isosurface, Isosurface):
+ raise ValueError('Expected an isosurface instance.')
+ self.__removeIsosurface(isosurface)
+ else:
+ raise ValueError('Expected an isosurface instance.')
+
+ def __addIsosurface(self, isosurface):
+ valueItem = IsoSurfaceRootItem(subject=isosurface)
+ nameItem = IsoSurfaceLevelItem(subject=isosurface)
+ self.insertRow(max(0, self.rowCount() - 1), [valueItem, nameItem])
+
+ def __removeIsosurface(self, isosurface):
+ for row in range(self.rowCount()):
+ child = self.child(row)
+ subject = getattr(child, 'subject', None)
+ if subject == isosurface:
+ self.takeRow(row)
+ break
+
+ def _init(self):
+ nameItem = IsoSurfaceAddRemoveItem(self.subject)
+ valueItem = qt.QStandardItem()
+ valueItem.setEditable(False)
+ self.appendRow([nameItem, valueItem])
+
+ subject = self.subject
+ isosurfaces = subject.getIsosurfaces()
+ for isosurface in isosurfaces:
+ self.__addIsosurface(isosurface)
+
+
+# Cutting Plane ###############################################################
+
+class ColormapBase(SubjectItem):
+ """
+ Mixin class for colormap items.
+ """
+
+ def getSignals(self):
+ return [self.subject.getCutPlanes()[0].sigColormapChanged]
+
+
+class PlaneMinRangeItem(ColormapBase):
+ """
+ colormap minVal item.
+ Editor is a QLineEdit with a QDoubleValidator
+ """
+ editable = True
+
+ def _pullData(self):
+ colormap = self.subject.getCutPlanes()[0].getColormap()
+ auto = colormap.isAutoscale()
+ if auto == self.isEnabled():
+ self._enableRow(not auto)
+ return colormap.getVMin()
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self._setVMin(value)
+
+ def _setVMin(self, value):
+ cutPlane = self.subject.getCutPlanes()[0]
+ colormap = cutPlane.getColormap()
+ vMin = value
+ vMax = colormap.getVMax()
+
+ if vMax is not None and value > vMax:
+ vMin = vMax
+ vMax = value
+ cutPlane.setColormap(name=colormap.getName(),
+ norm=colormap.getNorm(),
+ vmin=vMin,
+ vmax=vMax)
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QLineEdit(parent)
+ editor.setValidator(qt.QDoubleValidator())
+ return editor
+
+ def setEditorData(self, editor):
+ editor.setText(str(self._pullData()))
+ return True
+
+ def _setModelData(self, editor):
+ value = float(editor.text())
+ self._setVMin(value)
+ return True
+
+
+class PlaneMaxRangeItem(ColormapBase):
+ """
+ colormap maxVal item.
+ Editor is a QLineEdit with a QDoubleValidator
+ """
+ editable = True
+
+ def _pullData(self):
+ colormap = self.subject.getCutPlanes()[0].getColormap()
+ auto = colormap.isAutoscale()
+ if auto == self.isEnabled():
+ self._enableRow(not auto)
+ return self.subject.getCutPlanes()[0].getColormap().getVMax()
+
+ def _setVMax(self, value):
+ cutPlane = self.subject.getCutPlanes()[0]
+ colormap = cutPlane.getColormap()
+ vMin = colormap.getVMin()
+ vMax = value
+ if vMin is not None and value < vMin:
+ vMax = vMin
+ vMin = value
+ cutPlane.setColormap(name=colormap.getName(),
+ norm=colormap.getNorm(),
+ vmin=vMin,
+ vmax=vMax)
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QLineEdit(parent)
+ editor.setValidator(qt.QDoubleValidator())
+ return editor
+
+ def setEditorData(self, editor):
+ editor.setText(str(self._pullData()))
+ return True
+
+ def _setModelData(self, editor):
+ value = float(editor.text())
+ self._setVMax(value)
+ return True
+
+
+class PlaneOrientationItem(SubjectItem):
+ """
+ Plane orientation item.
+ Editor is a QComboBox.
+ """
+ editable = True
+
+ _PLANE_ACTIONS = (
+ ('3d-plane-normal-x', 'Plane 0',
+ 'Set plane perpendicular to red axis', (1., 0., 0.)),
+ ('3d-plane-normal-y', 'Plane 1',
+ 'Set plane perpendicular to green axis', (0., 1., 0.)),
+ ('3d-plane-normal-z', 'Plane 2',
+ 'Set plane perpendicular to blue axis', (0., 0., 1.)),
+ )
+
+ def getSignals(self):
+ return [self.subject.getCutPlanes()[0].sigPlaneChanged]
+
+ def _pullData(self):
+ currentNormal = self.subject.getCutPlanes()[0].getNormal()
+ for _, text, _, normal in self._PLANE_ACTIONS:
+ if numpy.array_equal(normal, currentNormal):
+ return text
+ return ''
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QComboBox(parent)
+ for iconName, text, tooltip, normal in self._PLANE_ACTIONS:
+ editor.addItem(getQIcon(iconName), text)
+ editor.currentIndexChanged[int].connect(self.__editorChanged)
+ return editor
+
+ def __editorChanged(self, index):
+ normal = self._PLANE_ACTIONS[index][3]
+ plane = self.subject.getCutPlanes()[0]
+ plane.setNormal(normal)
+ plane.moveToCenter()
+
+ def setEditorData(self, editor):
+ currentText = self._pullData()
+ index = 0
+ for normIdx, (_, text, _, _) in enumerate(self._PLANE_ACTIONS):
+ if text == currentText:
+ index = normIdx
+ break
+ editor.setCurrentIndex(index)
+ return True
+
+ def _setModelData(self, editor):
+ return True
+
+
+class PlaneInterpolationItem(SubjectItem):
+ """Toggle cut plane interpolation method: nearest or linear.
+
+ Item is checkable
+ """
+
+ def _init(self):
+ interpolation = self.subject.getCutPlanes()[0].getInterpolation()
+ self.setCheckable(True)
+ self.setCheckState(
+ qt.Qt.Checked if interpolation == 'linear' else qt.Qt.Unchecked)
+ self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
+
+ def getSignals(self):
+ return [self.subject.getCutPlanes()[0].sigInterpolationChanged]
+
+ def leftClicked(self):
+ checked = self.checkState() == qt.Qt.Checked
+ self._setInterpolation('linear' if checked else 'nearest')
+
+ def _pullData(self):
+ interpolation = self.subject.getCutPlanes()[0].getInterpolation()
+ self._setInterpolation(interpolation)
+ return interpolation[0].upper() + interpolation[1:]
+
+ def _setInterpolation(self, interpolation):
+ self.subject.getCutPlanes()[0].setInterpolation(interpolation)
+
+
+class PlaneColormapItem(ColormapBase):
+ """
+ colormap name item.
+ Editor is a QComboBox
+ """
+ editable = True
+
+ listValues = ['gray', 'reversed gray',
+ 'temperature', 'red',
+ 'green', 'blue']
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QComboBox(parent)
+ editor.addItems(self.listValues)
+ editor.currentIndexChanged[int].connect(self.__editorChanged)
+
+ return editor
+
+ def __editorChanged(self, index):
+ colorMapName = self.listValues[index]
+ colorMap = self.subject.getCutPlanes()[0].getColormap()
+ self.subject.getCutPlanes()[0].setColormap(name=colorMapName,
+ norm=colorMap.getNorm(),
+ vmin=colorMap.getVMin(),
+ vmax=colorMap.getVMax())
+
+ def setEditorData(self, editor):
+ colormapName = self.subject.getCutPlanes()[0].getColormap().getName()
+ index = self.listValues.index(colormapName)
+ editor.setCurrentIndex(index)
+ return True
+
+ def _setModelData(self, editor):
+ self.__editorChanged(editor.currentIndex())
+ return True
+
+ def _pullData(self):
+ return self.subject.getCutPlanes()[0].getColormap().getName()
+
+
+class PlaneAutoScaleItem(ColormapBase):
+ """
+ colormap autoscale item.
+ Item is checkable.
+ """
+
+ def _init(self):
+ colorMap = self.subject.getCutPlanes()[0].getColormap()
+ self.setCheckable(True)
+ self.setCheckState((colorMap.isAutoscale() and qt.Qt.Checked)
+ or qt.Qt.Unchecked)
+ self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
+
+ def leftClicked(self):
+ checked = (self.checkState() == qt.Qt.Checked)
+ self._setAutoScale(checked)
+
+ def _setAutoScale(self, auto):
+ view3d = self.subject
+ cutPlane = view3d.getCutPlanes()[0]
+ colormap = cutPlane.getColormap()
+
+ if auto != colormap.isAutoscale():
+ if auto:
+ vMin = vMax = None
+ else:
+ dataRange = view3d.getDataRange()
+ if dataRange is None or None in dataRange:
+ vMin = vMax = None
+ else:
+ vMin, vMax = dataRange
+ cutPlane.setColormap(colormap.getName(),
+ colormap.getNorm(),
+ vMin,
+ vMax)
+
+ def _pullData(self):
+ auto = self.subject.getCutPlanes()[0].getColormap().isAutoscale()
+ self._setAutoScale(auto)
+ if auto:
+ data = 'Auto'
+ else:
+ data = 'User'
+ return data
+
+
+class NormalizationNode(ColormapBase):
+ """
+ colormap normalization item.
+ Item is a QComboBox.
+ """
+ editable = True
+ listValues = ['linear', 'log']
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QComboBox(parent)
+ editor.addItems(self.listValues)
+ editor.currentIndexChanged[int].connect(self.__editorChanged)
+
+ return editor
+
+ def __editorChanged(self, index):
+ colorMap = self.subject.getCutPlanes()[0].getColormap()
+ normalization = self.listValues[index]
+ self.subject.getCutPlanes()[0].setColormap(name=colorMap.getName(),
+ norm=normalization,
+ vmin=colorMap.getVMin(),
+ vmax=colorMap.getVMax())
+
+ def setEditorData(self, editor):
+ normalization = self.subject.getCutPlanes()[0].getColormap().getNorm()
+ index = self.listValues.index(normalization)
+ editor.setCurrentIndex(index)
+ return True
+
+ def _setModelData(self, editor):
+ self.__editorChanged(editor.currentIndex())
+ return True
+
+ def _pullData(self):
+ return self.subject.getCutPlanes()[0].getColormap().getNorm()
+
+
+class PlaneGroup(SubjectItem):
+ """
+ Root Item for the plane items.
+ """
+ def _init(self):
+ valueItem = qt.QStandardItem()
+ valueItem.setEditable(False)
+ nameItem = PlaneVisibleItem(self.subject, 'Visible')
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Colormap')
+ nameItem.setEditable(False)
+ valueItem = PlaneColormapItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Normalization')
+ nameItem.setEditable(False)
+ valueItem = NormalizationNode(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Orientation')
+ nameItem.setEditable(False)
+ valueItem = PlaneOrientationItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Interpolation')
+ nameItem.setEditable(False)
+ valueItem = PlaneInterpolationItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Autoscale')
+ nameItem.setEditable(False)
+ valueItem = PlaneAutoScaleItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Min')
+ nameItem.setEditable(False)
+ valueItem = PlaneMinRangeItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Max')
+ nameItem.setEditable(False)
+ valueItem = PlaneMaxRangeItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+
+class PlaneVisibleItem(SubjectItem):
+ """
+ Plane visibility item.
+ Item is checkable.
+ """
+ def _init(self):
+ plane = self.subject.getCutPlanes()[0]
+ self.setCheckable(True)
+ self.setCheckState((plane.isVisible() and qt.Qt.Checked)
+ or qt.Qt.Unchecked)
+
+ def leftClicked(self):
+ plane = self.subject.getCutPlanes()[0]
+ checked = (self.checkState() == qt.Qt.Checked)
+ if checked != plane.isVisible():
+ plane.setVisible(checked)
+ if plane.isVisible():
+ plane.moveToCenter()
+
+
+# Tree ########################################################################
+
+class ItemDelegate(qt.QStyledItemDelegate):
+ """
+ Delegate for the QTreeView filled with SubjectItems.
+ """
+
+ sigDelegateEvent = qt.Signal(str)
+
+ def __init__(self, parent=None):
+ super(ItemDelegate, self).__init__(parent)
+
+ def createEditor(self, parent, option, index):
+ item = index.model().itemFromIndex(index)
+ if item:
+ if isinstance(item, SubjectItem):
+ editor = item.getEditor(parent, option, index)
+ if editor:
+ editor.setAutoFillBackground(True)
+ if hasattr(editor, 'sigViewTask'):
+ editor.sigViewTask.connect(self.__viewTask)
+ return editor
+
+ editor = super(ItemDelegate, self).createEditor(parent,
+ option,
+ index)
+ return editor
+
+ def updateEditorGeometry(self, editor, option, index):
+ editor.setGeometry(option.rect)
+
+ def setEditorData(self, editor, index):
+ item = index.model().itemFromIndex(index)
+ if item:
+ if isinstance(item, SubjectItem) and item.setEditorData(editor):
+ return
+ super(ItemDelegate, self).setEditorData(editor, index)
+
+ def setModelData(self, editor, model, index):
+ item = index.model().itemFromIndex(index)
+ if isinstance(item, SubjectItem) and item._setModelData(editor):
+ return
+ super(ItemDelegate, self).setModelData(editor, model, index)
+
+ def __viewTask(self, task):
+ self.sigDelegateEvent.emit(task)
+
+
+class TreeView(qt.QTreeView):
+ """
+ TreeView displaying the SubjectItems for the ScalarFieldView.
+ """
+
+ def __init__(self, parent=None):
+ super(TreeView, self).__init__(parent)
+ self.__openedIndex = None
+
+ self.setIconSize(qt.QSize(16, 16))
+
+ header = self.header()
+ if hasattr(header, 'setSectionResizeMode'): # Qt5
+ header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+ else: # Qt4
+ header.setResizeMode(qt.QHeaderView.ResizeToContents)
+
+ delegate = ItemDelegate()
+ self.setItemDelegate(delegate)
+ delegate.sigDelegateEvent.connect(self.__delegateEvent)
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+
+ self.clicked.connect(self.__clicked)
+
+ def setSfView(self, sfView):
+ """
+ Sets the ScalarFieldView this view is controlling.
+
+ :param sfView: A `ScalarFieldView`
+ """
+ model = qt.QStandardItemModel()
+ model.setColumnCount(ModelColumns.ColumnMax)
+ model.setHorizontalHeaderLabels(['Name', 'Value'])
+
+ item = qt.QStandardItem()
+ item.setEditable(False)
+ model.appendRow([ViewSettingsItem(sfView, 'Style'), item])
+
+ item = qt.QStandardItem()
+ item.setEditable(False)
+ model.appendRow([DataSetItem(sfView, 'Data'), item])
+
+ item = IsoSurfaceCount(sfView)
+ item.setEditable(False)
+ model.appendRow([IsoSurfaceGroup(sfView, 'Isosurfaces'), item])
+
+ item = qt.QStandardItem()
+ item.setEditable(False)
+ model.appendRow([PlaneGroup(sfView, 'Cutting Plane'), item])
+
+ self.setModel(model)
+
+ def setModel(self, model):
+ """
+ Reimplementation of the QTreeView.setModel method. It connects the
+ rowsRemoved signal and opens the persistent editors.
+
+ :param qt.QStandardItemModel model: the model
+ """
+
+ prevModel = self.model()
+ if prevModel:
+ self.__openPersistentEditors(qt.QModelIndex(), False)
+ try:
+ prevModel.rowsRemoved.disconnect(self.rowsRemoved)
+ except TypeError:
+ pass
+
+ super(TreeView, self).setModel(model)
+ model.rowsRemoved.connect(self.rowsRemoved)
+ self.__openPersistentEditors(qt.QModelIndex())
+
+ def __openPersistentEditors(self, parent=None, openEditor=True):
+ """
+ Opens or closes the items persistent editors.
+
+ :param qt.QModelIndex parent: starting index, or None if the whole tree
+ is to be considered.
+ :param bool openEditor: True to open the editors, False to close them.
+ """
+ model = self.model()
+
+ if not model:
+ return
+
+ if not parent or not parent.isValid():
+ parent = self.model().invisibleRootItem().index()
+
+ if openEditor:
+ meth = self.openPersistentEditor
+ else:
+ meth = self.closePersistentEditor
+
+ curParent = parent
+ children = [model.index(row, 0, curParent)
+ for row in range(model.rowCount(curParent))]
+
+ columnCount = model.columnCount()
+
+ while len(children) > 0:
+ curParent = children.pop(-1)
+
+ children.extend([model.index(row, 0, curParent)
+ for row in range(model.rowCount(curParent))])
+
+ for colIdx in range(columnCount):
+ sibling = model.sibling(curParent.row(),
+ colIdx,
+ curParent)
+ item = model.itemFromIndex(sibling)
+ if isinstance(item, SubjectItem) and item.persistent:
+ meth(sibling)
+
+ def rowsAboutToBeRemoved(self, parent, start, end):
+ """
+ Reimplementation of the QTreeView.rowsAboutToBeRemoved. Closes all
+ persistent editors under parent.
+
+ :param qt.QModelIndex parent: Parent index
+ :param int start: Start index from parent index (inclusive)
+ :param int end: End index from parent index (inclusive)
+ """
+ self.__openPersistentEditors(parent, False)
+ super(TreeView, self).rowsAboutToBeRemoved(parent, start, end)
+
+ def rowsRemoved(self, parent, start, end):
+ """
+ Called when QTreeView.rowsRemoved is emitted. Opens all persistent
+ editors under parent.
+
+ :param qt.QModelIndex parent: Parent index
+ :param int start: Start index from parent index (inclusive)
+ :param int end: End index from parent index (inclusive)
+ """
+ super(TreeView, self).rowsRemoved(parent, start, end)
+ self.__openPersistentEditors(parent, True)
+
+ def rowsInserted(self, parent, start, end):
+ """
+ Reimplementation of the QTreeView.rowsInserted. Opens all persistent
+ editors under parent.
+
+ :param qt.QModelIndex parent: Parent index
+ :param int start: Start index from parent index
+ :param int end: End index from parent index
+ """
+ self.__openPersistentEditors(parent, False)
+ super(TreeView, self).rowsInserted(parent, start, end)
+ self.__openPersistentEditors(parent)
+
+ def keyReleaseEvent(self, event):
+ """
+ Reimplementation of the QTreeView.keyReleaseEvent.
+ At the moment only Key_Delete is handled. It calls the selected item's
+ queryRemove method, and deleted the item if needed.
+
+ :param qt.QKeyEvent event: A key event
+ """
+
+ # TODO : better filtering
+ key = event.key()
+ modifiers = event.modifiers()
+
+ if key == qt.Qt.Key_Delete and modifiers == qt.Qt.NoModifier:
+ self.__removeIsosurfaces()
+
+ super(TreeView, self).keyReleaseEvent(event)
+
+ def __removeIsosurfaces(self):
+ model = self.model()
+ selected = self.selectedIndexes()
+ items = []
+ # WARNING : the selection mode is set to single, so we re not
+ # supposed to have more than one item here.
+ # Multiple selection deletion has not been tested.
+ # Watch out for index invalidation
+ for index in selected:
+ leftIndex = model.sibling(index.row(), 0, index)
+ leftItem = model.itemFromIndex(leftIndex)
+ if isinstance(leftItem, SubjectItem) and leftItem not in items:
+ items.append(leftItem)
+
+ isos = [item for item in items if isinstance(item, IsoSurfaceRootItem)]
+ if isos:
+ for iso in isos:
+ if iso.queryRemove(self):
+ parentItem = iso.parent()
+ parentItem.removeRow(iso.row())
+ else:
+ qt.QMessageBox.information(
+ self,
+ 'Remove isosurface',
+ 'Select an iso-surface to remove it')
+
+ def __clicked(self, index):
+ """
+ Called when the QTreeView.clicked signal is emitted. Calls the item's
+ leftClick method.
+
+ :param qt.QIndex index: An index
+ """
+ item = self.model().itemFromIndex(index)
+ if isinstance(item, SubjectItem):
+ item.leftClicked()
+
+ def __delegateEvent(self, task):
+ if task == 'remove_iso':
+ self.__removeIsosurfaces()
diff --git a/silx/gui/plot3d/ScalarFieldView.py b/silx/gui/plot3d/ScalarFieldView.py
new file mode 100644
index 0000000..2eb54a3
--- /dev/null
+++ b/silx/gui/plot3d/ScalarFieldView.py
@@ -0,0 +1,1385 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a window to view a 3D scalar field.
+
+It supports iso-surfaces, a cutting plane and the definition of
+a region of interest.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "10/01/2017"
+
+import re
+import logging
+import time
+from collections import deque
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.plot.Colors import rgba
+
+from silx.math.marchingcubes import MarchingCubes
+
+from .scene import axes, cutplane, function, interaction, primitives, transform
+from . import scene
+from .Plot3DWindow import Plot3DWindow
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _BoundedGroup(scene.Group):
+ """Group with data bounds"""
+
+ _shape = None # To provide a default value without overriding __init__
+
+ @property
+ def shape(self):
+ """Data shape (depth, height, width) of this group or None"""
+ return self._shape
+
+ @shape.setter
+ def shape(self, shape):
+ if shape is None:
+ self._shape = None
+ else:
+ depth, height, width = shape
+ self._shape = float(depth), float(height), float(width)
+
+ @property
+ def size(self):
+ """Data size (width, height, depth) of this group or None"""
+ shape = self.shape
+ if shape is None:
+ return None
+ else:
+ return shape[2], shape[1], shape[0]
+
+ @size.setter
+ def size(self, size):
+ if size is None:
+ self.shape = None
+ else:
+ self.shape = size[2], size[1], size[0]
+
+ def _bounds(self, dataBounds=False):
+ if dataBounds and self.size is not None:
+ return numpy.array(((0., 0., 0.), self.size),
+ dtype=numpy.float32)
+ else:
+ return super(_BoundedGroup, self)._bounds(dataBounds)
+
+
+class Isosurface(qt.QObject):
+ """Class representing an iso-surface
+
+ :param parent: The View widget this iso-surface belongs to
+ """
+
+ sigLevelChanged = qt.Signal(float)
+ """Signal emitted when the iso-surface level has changed.
+
+ This signal provides the new level value (might be nan).
+ """
+
+ sigColorChanged = qt.Signal()
+ """Signal emitted when the iso-surface color has changed"""
+
+ sigVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the iso-surface visibility has changed.
+
+ This signal provides the new visibility status.
+ """
+
+ def __init__(self, parent):
+ super(Isosurface, self).__init__(parent=parent)
+ self._level = float('nan')
+ self._autoLevelFunction = None
+ self._color = rgba('#FFD700FF')
+ self._data = None
+ self._group = scene.Group()
+
+ def _setData(self, data, copy=True):
+ """Set the data set from which to build the iso-surface.
+
+ :param numpy.ndarray data: The 3D dataset or None
+ :param bool copy: True to make a copy, False to use as is if possible
+ """
+ if data is None:
+ self._data = None
+ else:
+ self._data = numpy.array(data, copy=copy, order='C')
+
+ self._update()
+
+ def _get3DPrimitive(self):
+ """Return the group containing the mesh of the iso-surface if any"""
+ return self._group
+
+ def isVisible(self):
+ """Returns True if iso-surface is visible, else False"""
+ return self._group.visible
+
+ def setVisible(self, visible):
+ """Set the visibility of the iso-surface in the view.
+
+ :param bool visible: True to show the iso-surface, False to hide
+ """
+ visible = bool(visible)
+ if visible != self._group.visible:
+ self._group.visible = visible
+ self.sigVisibilityChanged.emit(visible)
+
+ def getLevel(self):
+ """Return the level of this iso-surface (float)"""
+ return self._level
+
+ def setLevel(self, level):
+ """Set the value at which to build the iso-surface.
+
+ Setting this value reset auto-level function
+
+ :param float level: The value at which to build the iso-surface
+ """
+ self._autoLevelFunction = None
+ level = float(level)
+ if level != self._level:
+ self._level = level
+ self._update()
+ self.sigLevelChanged.emit(level)
+
+ def isAutoLevel(self):
+ """True if iso-level is rebuild for each data set."""
+ return self.getAutoLevelFunction() is not None
+
+ def getAutoLevelFunction(self):
+ """Return the function computing the iso-level (callable or None)"""
+ return self._autoLevelFunction
+
+ def setAutoLevelFunction(self, autoLevel):
+ """Set the function used to compute the iso-level.
+
+ WARNING: The function might get called in a thread.
+
+ :param callable autoLevel:
+ A function taking a 3D numpy.ndarray of float32 and returning
+ a float used as iso-level.
+ Example: numpy.mean(data) + numpy.std(data)
+ """
+ assert callable(autoLevel)
+ self._autoLevelFunction = autoLevel
+ self._update()
+
+ def getColor(self):
+ """Return the color of this iso-surface (QColor)"""
+ return qt.QColor.fromRgbF(*self._color)
+
+ def setColor(self, color):
+ """Set the color of the iso-surface
+
+ :param color: RGBA color of the isosurface
+ :type color: QColor, str or array-like of 4 float in [0., 1.]
+ """
+ color = rgba(color)
+ if color != self._color:
+ self._color = color
+ if len(self._group.children) != 0:
+ self._group.children[0].setAttribute('color', self._color)
+ self.sigColorChanged.emit()
+
+ def _update(self):
+ """Update underlying mesh"""
+ self._group.children = []
+
+ if self._data is None:
+ if self.isAutoLevel():
+ self._level = float('nan')
+
+ else:
+ if self.isAutoLevel():
+ st = time.time()
+ try:
+ level = float(self.getAutoLevelFunction()(self._data))
+
+ except Exception:
+ module = self.getAutoLevelFunction().__module__
+ name = self.getAutoLevelFunction().__name__
+ _logger.error(
+ "Error while executing iso level function %s.%s",
+ module,
+ name,
+ exc_info=True)
+ level = float('nan')
+
+ else:
+ _logger.info(
+ 'Computed iso-level in %f s.', time.time() - st)
+
+ if level != self._level:
+ self._level = level
+ self.sigLevelChanged.emit(level)
+
+ if numpy.isnan(self._level):
+ return
+
+ st = time.time()
+ vertices, normals, indices = MarchingCubes(
+ self._data,
+ isolevel=self._level)
+ _logger.info('Computed iso-surface in %f s.', time.time() - st)
+
+ if len(vertices) == 0:
+ return
+ else:
+ mesh = primitives.Mesh3D(vertices,
+ colors=self._color,
+ normals=normals,
+ mode='triangles',
+ indices=indices)
+ self._group.children = [mesh]
+
+
+class Colormap(object):
+ """Description of a colormap
+
+ :param str name: Name of the colormap
+ :param str norm: Normalization: 'linear' (default) or 'log'
+ :param float vmin:
+ Lower bound of the colormap or None for autoscale (default)
+ :param float vmax:
+ Upper bounds of the colormap or None for autoscale (default)
+ """
+
+ def __init__(self, name, norm='linear', vmin=None, vmax=None):
+ assert name in function.Colormap.COLORMAPS
+ self._name = str(name)
+
+ assert norm in ('linear', 'log')
+ self._norm = str(norm)
+
+ self._vmin = float(vmin) if vmin is not None else None
+ self._vmax = float(vmax) if vmax is not None else None
+
+ def isAutoscale(self):
+ """True if both min and max are in autoscale mode"""
+ return self._vmin is None or self._vmax is None
+
+ def getName(self):
+ """Return the name of the colormap (str)"""
+ return self._name
+
+ def getNorm(self):
+ """Return the normalization of the colormap (str)"""
+ return self._norm
+
+ def getVMin(self):
+ """Return the lower bound of the colormap or None"""
+ return self._vmin
+
+ def getVMax(self):
+ """Return the upper bounds of the colormap or None"""
+ return self._vmax
+
+
+class SelectedRegion(object):
+ """Selection of a 3D region aligned with the axis.
+
+ :param arrayRange: Range of the selection in the array
+ ((zmin, zmax), (ymin, ymax), (xmin, xmax))
+ :param translation: Offset from array to data coordinates (ox, oy, oz)
+ :param scale: Scale from array to data coordinates (sx, sy, sz)
+ """
+
+ def __init__(self, arrayRange,
+ translation=(0., 0., 0.),
+ scale=(1., 1., 1.)):
+ self._arrayRange = numpy.array(arrayRange, copy=True, dtype=numpy.int)
+ assert self._arrayRange.shape == (3, 2)
+ assert numpy.all(self._arrayRange[:, 1] >= self._arrayRange[:, 0])
+ self._translation = numpy.array(translation, dtype=numpy.float32)
+ assert self._translation.shape == (3,)
+ self._scale = numpy.array(scale, dtype=numpy.float32)
+ assert self._scale.shape == (3,)
+
+ self._dataRange = (self._translation.reshape(3, -1) +
+ self._arrayRange[::-1] * self._scale.reshape(3, -1))
+
+ def getArrayRange(self):
+ """Returns array ranges of the selection: 3x2 array of int
+
+ :return: A numpy array with ((zmin, zmax), (ymin, ymax), (xmin, xmax))
+ :rtype: numpy.ndarray
+ """
+ return self._arrayRange.copy()
+
+ def getArraySlices(self):
+ """Slices corresponding to the selected range in the array
+
+ :return: A numpy array with (zslice, yslice, zslice)
+ :rtype: numpy.ndarray
+ """
+ return (slice(*self._arrayRange[0]),
+ slice(*self._arrayRange[1]),
+ slice(*self._arrayRange[2]))
+
+ def getDataRange(self):
+ """Range in the data coordinates of the selection: 3x2 array of float
+
+ :return: A numpy array with ((xmin, xmax), (ymin, ymax), (zmin, zmax))
+ :rtype: numpy.ndarray
+ """
+ return self._dataRange.copy()
+
+ def getDataScale(self):
+ """Scale from array to data coordinates: (sx, sy, sz)
+
+ :return: A numpy array with (sx, sy, sz)
+ :rtype: numpy.ndarray
+ """
+ return self._scale.copy()
+
+ def getDataTranslation(self):
+ """Offset from array to data coordinates: (ox, oy, oz)
+
+ :return: A numpy array with (ox, oy, oz)
+ :rtype: numpy.ndarray
+ """
+ return self._translation.copy()
+
+
+class CutPlane(qt.QObject):
+ """Class representing a cutting plane
+
+ :param ScalarFieldView sfView: Widget in which the cut plane is applied.
+ """
+
+ sigVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the cut visibility has changed.
+
+ This signal provides the new visibility status.
+ """
+
+ sigDataChanged = qt.Signal()
+ """Signal emitted when the data this plane is cutting has changed."""
+
+ sigPlaneChanged = qt.Signal()
+ """Signal emitted when the cut plane has moved"""
+
+ sigColormapChanged = qt.Signal(object)
+ """Signal emitted when the colormap has changed
+
+ This signal provides the new colormap.
+ """
+
+ sigInterpolationChanged = qt.Signal(str)
+ """Signal emitted when the cut plane interpolation has changed
+
+ This signal provides the new interpolation mode.
+ """
+
+ def __init__(self, sfView):
+ super(CutPlane, self).__init__(parent=sfView)
+
+ self._colormap = Colormap(
+ name='gray', norm='linear', vmin=None, vmax=None)
+
+ self._dataRange = None
+ self._positiveMin = None
+
+ self._plane = cutplane.CutPlane(normal=(0, 1, 0))
+ self._plane.alpha = 1.
+ self._plane.visible = self._visible = False
+ self._plane.addListener(self._planeChanged)
+ self._plane.plane.addListener(self._planePositionChanged)
+
+ sfView.sigDataChanged.connect(self._sfViewDataChanged)
+
+ def _get3DPrimitive(self):
+ """Return the cut plane scene node"""
+ return self._plane
+
+ def _sfViewDataChanged(self):
+ """Handle data change in the ScalarFieldView this plane belongs to"""
+ self._plane.setData(self.sender().getData(), copy=False)
+ self._dataRange = self.sender().getDataRange()
+ self._positiveMin = None
+ self.sigDataChanged.emit()
+
+ # Update colormap range when autoscale
+ if self.getColormap().isAutoscale():
+ self._updateColormapRange()
+
+ def _planeChanged(self, source, *args, **kwargs):
+ """Handle events from the plane primitive"""
+ # Using _visible for now, until scene as more info in events
+ if source.visible != self._visible:
+ self._visible = source.visible
+ self.sigVisibilityChanged.emit(source.visible)
+
+ def _planePositionChanged(self, source, *args, **kwargs):
+ """Handle update of cut plane position and normal"""
+ if self._plane.visible:
+ self.sigPlaneChanged.emit()
+
+ # Plane position
+
+ def moveToCenter(self):
+ """Move cut plane to center of data set"""
+ self._plane.moveToCenter()
+
+ def isValid(self):
+ """Returns whether the cut plane is defined or not (bool)"""
+ return self._plane.isValid
+
+ def getNormal(self):
+ """Returns the normal of the plane (as a unit vector)
+
+ :return: Normal (nx, ny, nz), vector is 0 if no plane is defined
+ :rtype: numpy.ndarray
+ """
+ return self._plane.plane.normal
+
+ def setNormal(self, normal):
+ """Set the normal of the plane
+
+ :param normal: 3-tuple of float: nx, ny, nz
+ """
+ self._plane.plane.normal = normal
+
+ def getPoint(self):
+ """Returns a point on the plane
+
+ :return: (x, y, z)
+ :rtype: numpy.ndarray
+ """
+ return self._plane.plane.point
+
+ def getParameters(self):
+ """Returns the plane equation parameters: a*x + b*y + c*z + d = 0
+
+ :return: Plane equation parameters: (a, b, c, d)
+ :rtype: numpy.ndarray
+ """
+ return self._plane.plane.parameters
+
+ # Visibility
+
+ def isVisible(self):
+ """Returns True if the plane is visible, False otherwise"""
+ return self._plane.visible
+
+ def setVisible(self, visible):
+ """Set the visibility of the plane
+
+ :param bool visible: True to make plane visible
+ """
+ self._plane.visible = visible
+
+ # Border stroke
+
+ def getStrokeColor(self):
+ """Returns the color of the plane border (QColor)"""
+ return qt.QColor.fromRgbF(*self._plane.color)
+
+ def setStrokeColor(self, color):
+ """Set the color of the plane border.
+
+ :param color: RGB color: name, #RRGGBB or RGB values
+ :type color:
+ QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8
+ """
+ self._plane.color = rgba(color)
+
+ # Data
+
+ def getImageData(self):
+ """Returns the data and information corresponding to the cut plane.
+
+ The returned data is not interpolated,
+ it is a slice of the 3D scalar field.
+
+ Image data axes are so that plane normal is towards the point of view.
+
+ :return: An object containing the 2D data slice and information
+ """
+ return _CutPlaneImage(self)
+
+ # Interpolation
+
+ def getInterpolation(self):
+ """Returns the interpolation used to display to cut plane.
+
+ :return: 'nearest' or 'linear'
+ :rtype: str
+ """
+ return self._plane.interpolation
+
+ def setInterpolation(self, interpolation):
+ """Set the interpolation used to display to cut plane
+
+ The default interpolation is 'linear'
+
+ :param str interpolation: 'nearest' or 'linear'
+ """
+ if interpolation != self.getInterpolation():
+ self._plane.interpolation = interpolation
+ self.sigInterpolationChanged.emit(interpolation)
+
+ # Colormap
+
+ # def getAlpha(self):
+ # """Returns the transparency of the plane as a float in [0., 1.]"""
+ # return self._plane.alpha
+
+ # def setAlpha(self, alpha):
+ # """Set the plane transparency.
+ #
+ # :param float alpha: Transparency in [0., 1]
+ # """
+ # self._plane.alpha = alpha
+
+ def getColormap(self):
+ """Returns the colormap set by :meth:`getColormap`.
+
+ :return: The colormap
+ :rtype: Colormap
+ """
+ return self._colormap
+
+ def setColormap(self,
+ name='gray',
+ norm='linear',
+ vmin=None,
+ vmax=None):
+ """Set the colormap to use.
+
+ :param str name: Name of the colormap in
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ :param str norm: Colormap mapping: 'linear' or 'log'.
+ :param float vmin: The minimum value of the range or None for autoscale
+ :param float vmax: The maximum value of the range or None for autoscale
+ """
+ _logger.debug('setColormap %s %s (%s, %s)',
+ name, norm, str(vmin), str(vmax))
+
+ self._colormap = Colormap(
+ name=name, norm=norm, vmin=vmin, vmax=vmax)
+
+ self._updateColormapRange()
+ self.sigColormapChanged.emit(self.getColormap())
+
+ def getColormapEffectiveRange(self):
+ """Returns the currently used range of the colormap.
+
+ This range is computed from the data set if colormap is in autoscale.
+ Range is clipped to positive values when using log scale.
+
+ :return: 2-tuple of float
+ """
+ return self._plane.colormap.range_
+
+ def _updateColormapRange(self):
+ """Update the colormap range"""
+ colormap = self.getColormap()
+
+ self._plane.colormap.name = colormap.getName()
+ if colormap.isAutoscale():
+ range_ = self._dataRange
+ if range_ is None: # No data, use a default range
+ range_ = 1., 10.
+ else:
+ range_ = colormap.getVMin(), colormap.getVMax()
+
+ if colormap.getNorm() == 'linear':
+ self._plane.colormap.norm = 'linear'
+ self._plane.colormap.range_ = range_
+
+ else: # Log
+ # Make sure range is strictly positive
+ if range_[0] <= 0.:
+ data = self._plane.getData(copy=False)
+ if data is not None:
+ if self._positiveMin is None:
+ # TODO compute this with the range as a combo operation
+ self._positiveMin = numpy.min(data[data > 0.])
+ range_ = (self._positiveMin,
+ max(range_[1], self._positiveMin))
+
+ self._plane.colormap.range_ = range_
+ self._plane.colormap.norm = colormap.getNorm()
+
+
+class _CutPlaneImage(object):
+ """Object representing the data sliced by a cut plane
+
+ :param CutPlane cutPlane: The CutPlane from which to generate image info
+ """
+
+ def __init__(self, cutPlane):
+ # Init attributes with default values
+ self._isValid = False
+ self._data = numpy.array([])
+ self._xLabel = ''
+ self._yLabel = ''
+ self._normalLabel = ''
+ self._scale = 1., 1.
+ self._translation = 0., 0.
+ self._index = 0
+ self._position = 0.
+
+ sfView = cutPlane.parent()
+ if not sfView or not cutPlane.isValid():
+ _logger.info("No plane available")
+ return
+
+ data = sfView.getData(copy=False)
+ if data is None:
+ _logger.info("No data available")
+ return
+
+ normal = cutPlane.getNormal()
+ point = numpy.array(cutPlane.getPoint(), dtype=numpy.int)
+
+ if numpy.all(numpy.equal(normal, (1., 0., 0.))):
+ index = max(0, min(point[0], data.shape[2] - 1))
+ slice_ = data[:, :, index]
+ xAxisIndex, yAxisIndex, normalAxisIndex = 1, 2, 0 # y, z, x
+ elif numpy.all(numpy.equal(normal, (0., 1., 0.))):
+ index = max(0, min(point[1], data.shape[1] - 1))
+ slice_ = numpy.transpose(data[:, index, :])
+ xAxisIndex, yAxisIndex, normalAxisIndex = 2, 0, 1 # z, x, y
+ elif numpy.all(numpy.equal(normal, (0., 0., 1.))):
+ index = max(0, min(point[2], data.shape[0] - 1))
+ slice_ = data[index, :, :]
+ xAxisIndex, yAxisIndex, normalAxisIndex = 0, 1, 2 # x, y, z
+ else:
+ _logger.warning('Unsupported normal: (%f, %f, %f)',
+ normal[0], normal[1], normal[2])
+ return
+
+ # Store cut plane image info
+
+ self._isValid = True
+ self._data = numpy.array(slice_, copy=True)
+
+ labels = sfView.getAxesLabels()
+ scale = sfView.getScale()
+ translation = sfView.getTranslation()
+
+ self._xLabel = labels[xAxisIndex]
+ self._yLabel = labels[yAxisIndex]
+ self._normalLabel = labels[normalAxisIndex]
+
+ self._scale = scale[xAxisIndex], scale[yAxisIndex]
+ self._translation = translation[xAxisIndex], translation[yAxisIndex]
+
+ self._index = index
+ self._position = float(index * scale[normalAxisIndex] +
+ translation[normalAxisIndex])
+
+ def isValid(self):
+ """Returns True if the cut plane image is defined (bool)"""
+ return self._isValid
+
+ def getData(self, copy=True):
+ """Returns the image data sliced by the cut plane.
+
+ :param bool copy: True to get a copy, False otherwise
+ :return: The 2D image data corresponding to the cut plane
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._data, copy=copy)
+
+ def getXLabel(self):
+ """Returns the label associated to the X axis of the image (str)"""
+ return self._xLabel
+
+ def getYLabel(self):
+ """Returns the label associated to the Y axis of the image (str)"""
+ return self._yLabel
+
+ def getNormalLabel(self):
+ """Returns the label of the 3D axis of the plane normal (str)"""
+ return self._normalLabel
+
+ def getScale(self):
+ """Returns the scales of the data as a 2-tuple of float (sx, sy)"""
+ return self._scale
+
+ def getTranslation(self):
+ """Returns the offset of the data as a 2-tuple of float (ox, oy)"""
+ return self._translation
+
+ def getIndex(self):
+ """Returns the index in the data array of the cut plane (int)"""
+ return self._index
+
+ def getPosition(self):
+ """Returns the cut plane position along the normal axis (flaot)"""
+ return self._position
+
+
+class ScalarFieldView(Plot3DWindow):
+ """Widget computing and displaying an iso-surface from a 3D scalar dataset.
+
+ Limitation: Currently, iso-surfaces are generated with higher values
+ than the iso-level 'inside' the surface.
+
+ :param parent: See :class:`QMainWindow`
+ """
+
+ sigDataChanged = qt.Signal()
+ """Signal emitted when the scalar data field has changed."""
+
+ sigSelectedRegionChanged = qt.Signal(object)
+ """Signal emitted when the selected region has changed.
+
+ This signal provides the new selected region.
+ """
+
+ def __init__(self, parent=None):
+ super(ScalarFieldView, self).__init__(parent)
+ self._colormap = Colormap(
+ name='gray', norm='linear', vmin=None, vmax=None)
+ self._selectedRange = None
+
+ # Store iso-surfaces
+ self._isosurfaces = []
+
+ # Transformations
+ self._dataScale = transform.Scale()
+ self._dataTranslate = transform.Translate()
+
+ self._foregroundColor = 1., 1., 1., 1.
+ self._highlightColor = 0.7, 0.7, 0., 1.
+
+ self._data = None
+ self._dataRange = None
+
+ self._group = _BoundedGroup()
+ self._group.transforms = [self._dataTranslate, self._dataScale]
+
+ self._selectionBox = primitives.Box()
+ self._selectionBox.strokeSmooth = False
+ self._selectionBox.strokeWidth = 1.
+ # self._selectionBox.fillColor = 1., 1., 1., 0.3
+ # self._selectionBox.fillCulling = 'back'
+ self._selectionBox.visible = False
+ self._group.children.append(self._selectionBox)
+
+ self._cutPlane = CutPlane(sfView=self)
+ self._cutPlane.sigVisibilityChanged.connect(
+ self._planeVisibilityChanged)
+ self._group.children.append(self._cutPlane._get3DPrimitive())
+
+ self._isogroup = primitives.GroupDepthOffset()
+ self._isogroup.transforms = [
+ # Convert from z, y, x from marching cubes to x, y, z
+ transform.Matrix((
+ (0., 0., 1., 0.),
+ (0., 1., 0., 0.),
+ (1., 0., 0., 0.),
+ (0., 0., 0., 1.))),
+ # Offset to match cutting plane coords
+ transform.Translate(0.5, 0.5, 0.5)
+ ]
+ self._group.children.append(self._isogroup)
+
+ self._bbox = axes.LabelledAxes()
+ self._bbox.children = [self._group]
+ self.getPlot3DWidget().viewport.scene.children.append(self._bbox)
+
+ self._initInteractionToolBar()
+
+ self._updateColors()
+
+ self.getPlot3DWidget().viewport.light.shininess = 32
+
+ def saveConfig(self, ioDevice):
+ """
+ Saves this view state. Only isosurfaces at the moment. Does not save
+ the isosurface's function.
+
+ :param qt.QIODevice ioDevice: A `qt.QIODevice`.
+ """
+
+ stream = qt.QDataStream(ioDevice)
+
+ stream.writeString('<ScalarFieldView>')
+
+ isoSurfaces = self.getIsosurfaces()
+
+ nIsoSurfaces = len(isoSurfaces)
+
+ # TODO : delegate the serialization to the serialized items
+ # isosurfaces
+ if nIsoSurfaces:
+ tagIn = '<IsoSurfaces nIso={0}>'.format(nIsoSurfaces)
+ stream.writeString(tagIn)
+
+ for surface in isoSurfaces:
+ color = surface.getColor()
+ level = surface.getLevel()
+ visible = surface.isVisible()
+ stream << color
+ stream.writeDouble(level)
+ stream.writeBool(visible)
+
+ stream.writeString('</IsoSurfaces>')
+
+ stream.writeString('<Style>')
+ background = self.getBackgroundColor()
+ foreground = self.getForegroundColor()
+ highlight = self.getHighlightColor()
+ stream << background << foreground << highlight
+ stream.writeString('</Style>')
+
+ stream.writeString('</ScalarFieldView>')
+
+ def loadConfig(self, ioDevice):
+ """
+ Loads this view state.
+ See ScalarFieldView.saveView to know what is supported at the moment.
+
+ :param qt.QIODevice ioDevice: A `qt.QIODevice`.
+ """
+
+ tagStack = deque()
+
+ tagInRegex = re.compile('<(?P<itemId>[^ /]*) *'
+ '(?P<args>.*)>')
+
+ tagOutRegex = re.compile('</(?P<itemId>[^ ]*)>')
+
+ tagRootInRegex = re.compile('<ScalarFieldView>')
+
+ isoSurfaceArgsRegex = re.compile('nIso=(?P<nIso>[0-9]*)')
+
+ stream = qt.QDataStream(ioDevice)
+
+ tag = stream.readString()
+ tagMatch = tagRootInRegex.match(tag)
+
+ if tagMatch is None:
+ # TODO : explicit error
+ raise ValueError('Unknown data.')
+
+ itemId = 'ScalarFieldView'
+
+ tagStack.append(itemId)
+
+ while True:
+
+ tag = stream.readString()
+
+ tagMatch = tagOutRegex.match(tag)
+ if tagMatch:
+ closeId = tagMatch.groupdict()['itemId']
+ if closeId != itemId:
+ # TODO : explicit error
+ raise ValueError('Unexpected closing tag {0} '
+ '(expected {1})'
+ ''.format(closeId, itemId))
+
+ if itemId == 'ScalarFieldView':
+ # reached end
+ break
+ else:
+ itemId = tagStack.pop()
+ # fetching next tag
+ continue
+
+ tagMatch = tagInRegex.match(tag)
+
+ if tagMatch is None:
+ # TODO : explicit error
+ raise ValueError('Unknown data.')
+
+ tagStack.append(itemId)
+
+ matchDict = tagMatch.groupdict()
+
+ itemId = matchDict['itemId']
+
+ # TODO : delegate the deserialization to the serialized items
+ if itemId == 'IsoSurfaces':
+ argsMatch = isoSurfaceArgsRegex.match(matchDict['args'])
+ if not argsMatch:
+ # TODO : explicit error
+ raise ValueError('Failed to parse args "{0}".'
+ ''.format(matchDict['args']))
+ argsDict = argsMatch.groupdict()
+ nIso = int(argsDict['nIso'])
+ if nIso:
+ for surface in self.getIsosurfaces():
+ self.removeIsosurface(surface)
+ for isoIdx in range(nIso):
+ color = qt.QColor()
+ stream >> color
+ level = stream.readDouble()
+ visible = stream.readBool()
+ surface = self.addIsosurface(level, color=color)
+ surface.setVisible(visible)
+ elif itemId == 'Style':
+ background = qt.QColor()
+ foreground = qt.QColor()
+ highlight = qt.QColor()
+ stream >> background >> foreground >> highlight
+ self.setBackgroundColor(background)
+ self.setForegroundColor(foreground)
+ self.setHighlightColor(highlight)
+ else:
+ raise ValueError('Unknown entry tag {0}.'
+ ''.format(itemId))
+
+ def _initInteractionToolBar(self):
+ self._interactionToolbar = qt.QToolBar()
+ self._interactionToolbar.setEnabled(False)
+
+ group = qt.QActionGroup(self._interactionToolbar)
+ group.setExclusive(True)
+
+ self._cameraAction = qt.QAction(None)
+ self._cameraAction.setText('camera')
+ self._cameraAction.setCheckable(True)
+ self._cameraAction.setToolTip('Control camera')
+ self._cameraAction.setChecked(True)
+ group.addAction(self._cameraAction)
+
+ self._planeAction = qt.QAction(None)
+ self._planeAction.setText('plane')
+ self._planeAction.setCheckable(True)
+ self._planeAction.setToolTip('Control cutting plane')
+ group.addAction(self._planeAction)
+ group.triggered.connect(self._interactionChanged)
+
+ self._interactionToolbar.addActions(group.actions())
+ self.addToolBar(self._interactionToolbar)
+
+ def _planeVisibilityChanged(self, visible):
+ """Handle visibility events from the plane"""
+ if visible != self._interactionToolbar.isEnabled():
+ if visible:
+ self._interactionToolbar.setEnabled(True)
+ self.setInteractiveMode('plane')
+ else:
+ self._interactionToolbar.setEnabled(False)
+ self.setInteractiveMode('camera')
+
+ def _interactionChanged(self, action):
+ self.setInteractiveMode(action.text())
+
+ def setInteractiveMode(self, mode):
+ """Choose the current interaction.
+
+ :param str mode: Either plane or camera
+ """
+ if mode == self.getInteractiveMode():
+ return
+
+ sceneScale = self.getPlot3DWidget().viewport.scene.transforms[0]
+ if mode == 'plane':
+ self.getPlot3DWidget().eventHandler = \
+ interaction.PanPlaneRotateCameraControl(
+ self.getPlot3DWidget().viewport,
+ self._cutPlane._get3DPrimitive(),
+ mode='position',
+ scaleTransform=sceneScale)
+ self._planeAction.setChecked(True)
+ elif mode == 'camera':
+ self.getPlot3DWidget().eventHandler = interaction.CameraControl(
+ self.getPlot3DWidget().viewport, orbitAroundCenter=False,
+ mode='position', scaleTransform=sceneScale,
+ selectCB=None)
+ self._cameraAction.setChecked(True)
+ else:
+ raise ValueError('Unsupported interactive mode %s', str(mode))
+ self._updateColors()
+
+ def getInteractiveMode(self):
+ """Returns the current interaction mode, see :meth:`setInteractiveMode`
+ """
+ if isinstance(self.getPlot3DWidget().eventHandler,
+ interaction.PanPlaneRotateCameraControl):
+ return 'plane'
+ elif isinstance(self.getPlot3DWidget().eventHandler,
+ interaction.CameraControl):
+ return 'camera'
+ else:
+ raise RuntimeError('Unknown interactive mode')
+
+ # Handle scalar field
+
+ def setData(self, data, copy=True):
+ """Set the 3D scalar data set to use for building the iso-surface.
+
+ Dataset order is zyx (i.e., first dimension is z).
+
+ :param data: scalar field from which to extract the iso-surface
+ :type data: 3D numpy.ndarray of float32 with shape at least (2, 2, 2)
+ :param bool copy:
+ True (default) to make a copy,
+ False to avoid copy (DO NOT MODIFY data afterwards)
+ """
+ if data is None:
+ self._data = None
+ self._dataRange = None
+ self.setSelectedRegion(zrange=None, yrange=None, xrange_=None)
+ self._group.shape = None
+ self.centerScene()
+
+ else:
+ data = numpy.array(data, copy=copy, dtype=numpy.float32, order='C')
+ assert data.ndim == 3
+ assert min(data.shape) >= 2
+
+ wasData = self._data is not None
+ previousSelectedRegion = self.getSelectedRegion()
+
+ self._data = data
+ self._dataRange = self._data.min(), self._data.max()
+
+ if previousSelectedRegion is not None:
+ # Update selected region to ensure it is clipped to array range
+ self.setSelectedRegion(*previousSelectedRegion.getArrayRange())
+
+ self._group.shape = self._data.shape
+
+ if not wasData:
+ self.centerScene() # Reset viewpoint the first time only
+
+ # Update iso-surfaces
+ for isosurface in self.getIsosurfaces():
+ isosurface._setData(self._data, copy=False)
+
+ self.sigDataChanged.emit()
+
+ def getData(self, copy=True):
+ """Get the 3D scalar data currently used to build the iso-surface.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get the internal data (DO NOT modify!)
+ :return: The data set (or None if not set)
+ """
+ if self._data is None:
+ return None
+ else:
+ return numpy.array(self._data, copy=copy)
+
+ def getDataRange(self):
+ """Return the range of the data as a 2-tuple (min, max)"""
+ return self._dataRange
+
+ # Transformations
+
+ def setScale(self, sx=1., sy=1., sz=1.):
+ """Set the scale of the 3D scalar field (i.e., size of a voxel).
+
+ :param float sx: Scale factor along the X axis
+ :param float sy: Scale factor along the Y axis
+ :param float sz: Scale factor along the Z axis
+ """
+ scale = numpy.array((sx, sy, sz), dtype=numpy.float32)
+ if not numpy.all(numpy.equal(scale, self.getScale())):
+ self._dataScale.scale = scale
+ self.centerScene() # Reset viewpoint
+
+ def getScale(self):
+ """Returns the scales provided by :meth:`setScale` as a numpy.ndarray.
+ """
+ return self._dataScale.scale
+
+ def setTranslation(self, x=0., y=0., z=0.):
+ """Set the translation of the origin of the data array in data coordinates.
+
+ :param float x: Offset of the data origin on the X axis
+ :param float y: Offset of the data origin on the Y axis
+ :param float z: Offset of the data origin on the Z axis
+ """
+ translation = numpy.array((x, y, z), dtype=numpy.float32)
+ if not numpy.all(numpy.equal(translation, self.getTranslation())):
+ self._dataTranslate.translation = translation
+ self.centerScene() # Reset viewpoint
+
+ def getTranslation(self):
+ """Returns the offset set by :meth:`setTranslation` as a numpy.ndarray.
+ """
+ return self._dataTranslate.translation
+
+ # Axes labels
+
+ def setAxesLabels(self, xlabel=None, ylabel=None, zlabel=None):
+ """Set the text labels of the axes.
+
+ :param str xlabel: Label of the X axis, None to leave unchanged.
+ :param str ylabel: Label of the Y axis, None to leave unchanged.
+ :param str zlabel: Label of the Z axis, None to leave unchanged.
+ """
+ if xlabel is not None:
+ self._bbox.xlabel = xlabel
+
+ if ylabel is not None:
+ self._bbox.ylabel = ylabel
+
+ if zlabel is not None:
+ self._bbox.zlabel = zlabel
+
+ class _Labels(tuple):
+ """Return type of :meth:`getAxesLabels`"""
+
+ def getXLabel(self):
+ """Label of the X axis (str)"""
+ return self[0]
+
+ def getYLabel(self):
+ """Label of the Y axis (str)"""
+ return self[1]
+
+ def getZLabel(self):
+ """Label of the Z axis (str)"""
+ return self[2]
+
+ def getAxesLabels(self):
+ """Returns the text labels of the axes
+
+ >>> widget = ScalarFieldView()
+ >>> widget.setAxesLabels(xlabel='X')
+
+ You can get the labels either as a 3-tuple:
+
+ >>> xlabel, ylabel, zlabel = widget.getAxesLabels()
+
+ Or as an object with methods getXLabel, getYLabel and getZLabel:
+
+ >>> labels = widget.getAxesLabels()
+ >>> labels.getXLabel()
+ ... 'X'
+
+ :return: object describing the labels
+ """
+ return self._Labels((self._bbox.xlabel,
+ self._bbox.ylabel,
+ self._bbox.zlabel))
+
+ # Colors
+
+ def _updateColors(self):
+ """Update item depending on foreground/highlight color"""
+ self._bbox.tickColor = self._foregroundColor
+ self._selectionBox.strokeColor = self._foregroundColor
+ if self.getInteractiveMode() == 'plane':
+ self._cutPlane.setStrokeColor(self._highlightColor)
+ self._bbox.color = self._foregroundColor
+ else:
+ self._cutPlane.setStrokeColor(self._foregroundColor)
+ self._bbox.color = self._highlightColor
+
+ def getForegroundColor(self):
+ """Return color used for text and bounding box (QColor)"""
+ return qt.QColor.fromRgbF(*self._foregroundColor)
+
+ def setForegroundColor(self, color):
+ """Set the foreground color.
+
+ :param color: RGB color: name, #RRGGBB or RGB values
+ :type color:
+ QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8
+ """
+ color = rgba(color)
+ if color != self._foregroundColor:
+ self._foregroundColor = color
+ self._updateColors()
+
+ def getHighlightColor(self):
+ """Return color used for highlighted item bounding box (QColor)"""
+ return qt.QColor.fromRgbF(*self._highlightColor)
+
+ def setHighlightColor(self, color):
+ """Set hightlighted item color.
+
+ :param color: RGB color: name, #RRGGBB or RGB values
+ :type color:
+ QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8
+ """
+ color = rgba(color)
+ if color != self._highlightColor:
+ self._highlightColor = color
+ self._updateColors()
+
+ # Cut Plane
+
+ def getCutPlanes(self):
+ """Return an iterable of all cut planes of the view.
+
+ This includes hidden cut planes.
+
+ For now, there is always one cut plane.
+ """
+ return (self._cutPlane,)
+
+ # Selection
+
+ def setSelectedRegion(self, zrange=None, yrange=None, xrange_=None):
+ """Set the 3D selected region aligned with the axes.
+
+ Provided range are array indices range.
+ The provided ranges are clipped to the data.
+ If a range is None, the range of the array on this dimension is used.
+
+ :param zrange: (zmin, zmax) range of the selection
+ :param yrange: (ymin, ymax) range of the selection
+ :param xrange_: (xmin, xmax) range of the selection
+ """
+ # No range given: unset selection
+ if zrange is None and yrange is None and xrange_ is None:
+ selectedRange = None
+
+ else:
+ # Handle default ranges
+ if self._data is not None:
+ if zrange is None:
+ zrange = 0, self._data.shape[0]
+ if yrange is None:
+ yrange = 0, self._data.shape[1]
+ if xrange_ is None:
+ xrange_ = 0, self._data.shape[2]
+
+ elif None in (xrange_, yrange, zrange):
+ # One of the range is None and no data available
+ raise RuntimeError(
+ 'Data is not set, cannot get default range from it.')
+
+ # Clip selected region to data shape and make sure min <= max
+ selectedRange = numpy.array((
+ (max(0, min(*zrange)),
+ min(self._data.shape[0], max(*zrange))),
+ (max(0, min(*yrange)),
+ min(self._data.shape[1], max(*yrange))),
+ (max(0, min(*xrange_)),
+ min(self._data.shape[2], max(*xrange_))),
+ ), dtype=numpy.int)
+
+ # numpy.equal supports None
+ if not numpy.all(numpy.equal(selectedRange, self._selectedRange)):
+ self._selectedRange = selectedRange
+
+ # Update scene accordingly
+ if self._selectedRange is None:
+ self._selectionBox.visible = False
+ else:
+ self._selectionBox.visible = True
+ scales = self._selectedRange[:, 1] - self._selectedRange[:, 0]
+ self._selectionBox.size = scales[::-1]
+ self._selectionBox.transforms = [
+ transform.Translate(*self._selectedRange[::-1, 0])]
+
+ self.sigSelectedRegionChanged.emit(self.getSelectedRegion())
+
+ def getSelectedRegion(self):
+ """Returns the currently selected region or None."""
+ if self._selectedRange is None:
+ return None
+ else:
+ return SelectedRegion(self._selectedRange,
+ translation=self.getTranslation(),
+ scale=self.getScale())
+
+ # Handle iso-surfaces
+
+ sigIsosurfaceAdded = qt.Signal(object)
+ """Signal emitted when a new iso-surface is added to the view.
+
+ The newly added iso-surface is provided by this signal
+ """
+
+ sigIsosurfaceRemoved = qt.Signal(object)
+ """Signal emitted when an iso-surface is removed from the view
+
+ The removed iso-surface is provided by this signal.
+ """
+
+ def addIsosurface(self, level, color):
+ """Add an iso-surface to the view.
+
+ :param level:
+ The value at which to build the iso-surface or a callable
+ (e.g., a function) taking a 3D numpy.ndarray as input and
+ returning a float.
+ Example: numpy.mean(data) + numpy.std(data)
+ :type level: float or callable
+ :param color: RGBA color of the isosurface
+ :type color: str or array-like of 4 float in [0., 1.]
+ :return: Isosurface object describing this isosurface
+ """
+ isosurface = Isosurface(parent=self)
+ isosurface.setColor(color)
+ if callable(level):
+ isosurface.setAutoLevelFunction(level)
+ else:
+ isosurface.setLevel(level)
+ isosurface._setData(self._data, copy=False)
+ isosurface.sigLevelChanged.connect(self._updateIsosurfaces)
+
+ self._isosurfaces.append(isosurface)
+
+ self._updateIsosurfaces()
+
+ self.sigIsosurfaceAdded.emit(isosurface)
+ return isosurface
+
+ def getIsosurfaces(self):
+ """Return an iterable of all iso-surfaces of the view"""
+ return tuple(self._isosurfaces)
+
+ def removeIsosurface(self, isosurface):
+ """Remove an iso-surface from the view.
+
+ :param isosurface: The isosurface object to remove"""
+ if isosurface not in self.getIsosurfaces():
+ _logger.warning(
+ "Try to remove isosurface that is not in the list: %s",
+ str(isosurface))
+ else:
+ isosurface.sigLevelChanged.disconnect(self._updateIsosurfaces)
+ self._isosurfaces.remove(isosurface)
+ self._updateIsosurfaces()
+ self.sigIsosurfaceRemoved.emit(isosurface)
+
+ def clearIsosurfaces(self):
+ """Remove all iso-surfaces from the view."""
+ for isosurface in self.getIsosurfaces():
+ self.removeIsosurface(isosurface)
+
+ def _updateIsosurfaces(self, level=None):
+ """Handle updates of iso-surfaces level and add/remove"""
+ # Sorting using minus, this supposes data 'object' to be max values
+ sortedIso = sorted(self.getIsosurfaces(),
+ key=lambda iso: - iso.getLevel())
+ self._isogroup.children = [iso._get3DPrimitive() for iso in sortedIso]
diff --git a/silx/gui/plot3d/ViewpointToolBar.py b/silx/gui/plot3d/ViewpointToolBar.py
new file mode 100644
index 0000000..d062c1b
--- /dev/null
+++ b/silx/gui/plot3d/ViewpointToolBar.py
@@ -0,0 +1,114 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a toolbar to control Plot3DWidget viewpoint."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/09/2016"
+
+
+from silx.gui import qt
+from silx.gui.icons import getQIcon
+
+
+class ViewpointActionGroup(qt.QActionGroup):
+ """ActionGroup of actions to reset the viewpoint.
+
+ As for QActionGroup, add group's actions to the widget with:
+ `widget.addActions(actionGroup.actions())`
+
+ :param Plot3DWidget plot3D: The widget for which to control the viewpoint
+ :param parent: See :class:`QActionGroup`
+ """
+
+ # Action information: icon name, text, tooltip
+ _RESET_CAMERA_ACTIONS = (
+ ('cube-front', 'Front', 'View along the -Z axis'),
+ ('cube-back', 'Back', 'View along the +Z axis'),
+ ('cube-top', 'Top', 'View along the -Y'),
+ ('cube-bottom', 'Bottom', 'View along the +Y'),
+ ('cube-right', 'Right', 'View along the -X'),
+ ('cube-left', 'Left', 'View along the +X'),
+ ('cube', 'Side', 'Side view')
+ )
+
+ def __init__(self, plot3D, parent=None):
+ super(ViewpointActionGroup, self).__init__(parent)
+ self.setExclusive(False)
+
+ self._plot3D = plot3D
+
+ for actionInfo in self._RESET_CAMERA_ACTIONS:
+ iconname, text, tooltip = actionInfo
+
+ action = qt.QAction(getQIcon(iconname), text, None)
+ action.setCheckable(False)
+ action.setToolTip(tooltip)
+ self.addAction(action)
+
+ self.triggered[qt.QAction].connect(self._actionGroupTriggered)
+
+ def _actionGroupTriggered(self, action):
+ actionname = action.text().lower()
+
+ self._plot3D.viewport.camera.extrinsic.reset(face=actionname)
+ self._plot3D.centerScene()
+
+
+class ViewpointToolBar(qt.QToolBar):
+ """A toolbar providing icons to reset the viewpoint.
+
+ :param parent: See :class:`QToolBar`
+ :param Plot3DWidget plot3D: The widget to control
+ :param str title: Title of the toolbar
+ """
+
+ def __init__(self, parent=None, plot3D=None, title='Viewpoint control'):
+ super(ViewpointToolBar, self).__init__(title, parent)
+
+ self._actionGroup = ViewpointActionGroup(plot3D)
+ assert plot3D is not None
+ self._plot3D = plot3D
+ self.addActions(self._actionGroup.actions())
+
+ # Choosing projection disabled for now
+ # Add projection combo box
+ # comboBoxProjection = qt.QComboBox()
+ # comboBoxProjection.addItem('Perspective')
+ # comboBoxProjection.addItem('Parallel')
+ # comboBoxProjection.setToolTip(
+ # 'Choose the projection:'
+ # ' perspective or parallel (i.e., orthographic)')
+ # comboBoxProjection.currentIndexChanged[(str)].connect(
+ # self._comboBoxProjectionCurrentIndexChanged)
+ # self.addWidget(qt.QLabel('Projection:'))
+ # self.addWidget(comboBoxProjection)
+
+ # def _comboBoxProjectionCurrentIndexChanged(self, text):
+ # """Projection combo box listener"""
+ # self._plot3D.setProjection(
+ # 'perspective' if text == 'Perspective' else 'orthographic')
diff --git a/silx/gui/plot3d/__init__.py b/silx/gui/plot3d/__init__.py
new file mode 100644
index 0000000..ad45424
--- /dev/null
+++ b/silx/gui/plot3d/__init__.py
@@ -0,0 +1,45 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This package provides widgets displaying 3D content based on OpenGL.
+
+It depends on PyOpenGL and QtOpenGL.
+"""
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/01/2017"
+
+
+from .. import qt as _qt
+
+if not _qt.HAS_OPENGL:
+ raise ImportError('Qt.QtOpenGL is not available')
+
+try:
+ import OpenGL as _OpenGL
+except ImportError:
+ raise ImportError('PyOpenGL is not installed')
diff --git a/silx/gui/plot3d/scene/__init__.py b/silx/gui/plot3d/scene/__init__.py
new file mode 100644
index 0000000..25a7171
--- /dev/null
+++ b/silx/gui/plot3d/scene/__init__.py
@@ -0,0 +1,34 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a 3D graphics scene graph structure."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/11/2016"
+
+
+from .core import Elem, Group, PrivateGroup # noqa
+from .viewport import Viewport # noqa
+from .window import Window # noqa
diff --git a/silx/gui/plot3d/scene/axes.py b/silx/gui/plot3d/scene/axes.py
new file mode 100644
index 0000000..528e4f7
--- /dev/null
+++ b/silx/gui/plot3d/scene/axes.py
@@ -0,0 +1,224 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Primitive displaying a text field in the scene."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/10/2016"
+
+
+import logging
+import numpy
+
+from ...plot._utils import ticklayout
+
+from . import core, primitives, text, transform
+
+
+_logger = logging.getLogger(__name__)
+
+
+class LabelledAxes(primitives.GroupBBox):
+ """A group displaying a bounding box with axes labels around its children.
+ """
+
+ def __init__(self):
+ super(LabelledAxes, self).__init__()
+ self._ticksForBounds = None
+
+ self._font = text.Font()
+
+ # TODO offset labels from anchor in pixels
+
+ self._xlabel = text.Text2D(font=self._font)
+ self._xlabel.align = 'center'
+ self._xlabel.transforms = [self._boxTransforms,
+ transform.Translate(tx=0.5)]
+ self._children.insert(-1, self._xlabel)
+
+ self._ylabel = text.Text2D(font=self._font)
+ self._ylabel.align = 'center'
+ self._ylabel.transforms = [self._boxTransforms,
+ transform.Translate(ty=0.5)]
+ self._children.insert(-1, self._ylabel)
+
+ self._zlabel = text.Text2D(font=self._font)
+ self._zlabel.align = 'center'
+ self._zlabel.transforms = [self._boxTransforms,
+ transform.Translate(tz=0.5)]
+ self._children.insert(-1, self._zlabel)
+
+ # Init tick lines with dummy pos
+ self._tickLines = primitives.DashedLines(
+ positions=((0., 0., 0.), (0., 0., 0.)))
+ self._tickLines.dash = 5, 10
+ self._tickLines.visible = False
+ self._children.insert(-1, self._tickLines)
+
+ self._tickLabels = core.Group()
+ self._children.insert(-1, self._tickLabels)
+
+ # Sync color
+ self.tickColor = 1., 1., 1., 1.
+
+ @property
+ def tickColor(self):
+ """Color of ticks and text labels.
+
+ This does NOT set bounding box color.
+ Use :attr:`color` for the bounding box.
+ """
+ return self._xlabel.foreground
+
+ @tickColor.setter
+ def tickColor(self, color):
+ self._xlabel.foreground = color
+ self._ylabel.foreground = color
+ self._zlabel.foreground = color
+ transparentColor = color[0], color[1], color[2], color[3] * 0.6
+ self._tickLines.setAttribute('color', transparentColor)
+ for label in self._tickLabels.children:
+ label.foreground = color
+
+ @property
+ def font(self):
+ """Font of axes text labels (Font)"""
+ return self._font
+
+ @font.setter
+ def font(self, font):
+ self._font = font
+ self._xlabel.font = font
+ self._ylabel.font = font
+ self._zlabel.font = font
+ for label in self._tickLabels.children:
+ label.font = font
+
+ @property
+ def xlabel(self):
+ """Text label of the X axis (str)"""
+ return self._xlabel.text
+
+ @xlabel.setter
+ def xlabel(self, text):
+ self._xlabel.text = text
+
+ @property
+ def ylabel(self):
+ """Text label of the Y axis (str)"""
+ return self._ylabel.text
+
+ @ylabel.setter
+ def ylabel(self, text):
+ self._ylabel.text = text
+
+ @property
+ def zlabel(self):
+ """Text label of the Z axis (str)"""
+ return self._zlabel.text
+
+ @zlabel.setter
+ def zlabel(self, text):
+ self._zlabel.text = text
+
+ def _updateTicks(self):
+ """Check if ticks need update and update them if needed."""
+ bounds = self._group.bounds(transformed=False, dataBounds=True)
+ if bounds is None: # No content
+ if self._ticksForBounds is not None:
+ self._ticksForBounds = None
+ self._tickLines.visible = False
+ self._tickLabels.children = [] # Reset previous labels
+
+ elif (self._ticksForBounds is None or
+ not numpy.all(numpy.equal(bounds, self._ticksForBounds))):
+ self._ticksForBounds = bounds
+
+ # Update ticks
+ ticklength = numpy.abs(bounds[1] - bounds[0])
+
+ xticks, xlabels = ticklayout.ticks(*bounds[:, 0])
+ yticks, ylabels = ticklayout.ticks(*bounds[:, 1])
+ zticks, zlabels = ticklayout.ticks(*bounds[:, 2])
+
+ # Update tick lines
+ coords = numpy.empty(
+ ((len(xticks) + len(yticks) + len(zticks)), 4, 3),
+ dtype=numpy.float32)
+ coords[:, :, :] = bounds[0, :] # account for offset from origin
+
+ xcoords = coords[:len(xticks)]
+ xcoords[:, :, 0] = numpy.asarray(xticks)[:, numpy.newaxis]
+ xcoords[:, 1, 1] += ticklength[1] # X ticks on XY plane
+ xcoords[:, 3, 2] += ticklength[2] # X ticks on XZ plane
+
+ ycoords = coords[len(xticks):len(xticks) + len(yticks)]
+ ycoords[:, :, 1] = numpy.asarray(yticks)[:, numpy.newaxis]
+ ycoords[:, 1, 0] += ticklength[0] # Y ticks on XY plane
+ ycoords[:, 3, 2] += ticklength[2] # Y ticks on YZ plane
+
+ zcoords = coords[len(xticks) + len(yticks):]
+ zcoords[:, :, 2] = numpy.asarray(zticks)[:, numpy.newaxis]
+ zcoords[:, 1, 0] += ticklength[0] # Z ticks on XZ plane
+ zcoords[:, 3, 1] += ticklength[1] # Z ticks on YZ plane
+
+ self._tickLines.setPositions(coords.reshape(-1, 3))
+ self._tickLines.visible = True
+
+ # Update labels
+ color = self.tickColor
+ offsets = bounds[0] - ticklength / 20.
+ labels = []
+ for tick, label in zip(xticks, xlabels):
+ text2d = text.Text2D(text=label, font=self.font)
+ text2d.align = 'center'
+ text2d.foreground = color
+ text2d.transforms = [transform.Translate(
+ tx=tick, ty=offsets[1], tz=offsets[2])]
+ labels.append(text2d)
+
+ for tick, label in zip(yticks, ylabels):
+ text2d = text.Text2D(text=label, font=self.font)
+ text2d.align = 'center'
+ text2d.foreground = color
+ text2d.transforms = [transform.Translate(
+ tx=offsets[0], ty=tick, tz=offsets[2])]
+ labels.append(text2d)
+
+ for tick, label in zip(zticks, zlabels):
+ text2d = text.Text2D(text=label, font=self.font)
+ text2d.align = 'center'
+ text2d.foreground = color
+ text2d.transforms = [transform.Translate(
+ tx=offsets[0], ty=offsets[1], tz=tick)]
+ labels.append(text2d)
+
+ self._tickLabels.children = labels # Reset previous labels
+
+ def prepareGL2(self, context):
+ self._updateTicks()
+ super(LabelledAxes, self).prepareGL2(context)
diff --git a/silx/gui/plot3d/scene/camera.py b/silx/gui/plot3d/scene/camera.py
new file mode 100644
index 0000000..8cc279d
--- /dev/null
+++ b/silx/gui/plot3d/scene/camera.py
@@ -0,0 +1,350 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides classes to handle a perspective projection in 3D."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import numpy
+
+from . import transform
+
+
+# CameraExtrinsic #############################################################
+
+class CameraExtrinsic(transform.Transform):
+ """Transform matrix to handle camera position and orientation.
+
+ :param position: Coordinates of the point of view.
+ :type position: numpy.ndarray-like of 3 float32.
+ :param direction: Sight direction vector.
+ :type direction: numpy.ndarray-like of 3 float32.
+ :param up: Vector pointing upward in the image plane.
+ :type up: numpy.ndarray-like of 3 float32.
+ """
+
+ def __init__(self, position=(0., 0., 0.),
+ direction=(0., 0., -1.),
+ up=(0., 1., 0.)):
+
+ super(CameraExtrinsic, self).__init__()
+ self._position = None
+ self.position = position # set _position
+ self._side = 1., 0., 0.
+ self._up = 0., 1., 0.
+ self._direction = 0., 0., -1.
+ self.setOrientation(direction=direction, up=up) # set _direction, _up
+
+ def _makeMatrix(self):
+ return transform.mat4LookAtDir(self._position,
+ self._direction, self._up)
+
+ def copy(self):
+ """Return an independent copy"""
+ return CameraExtrinsic(self.position, self.direction, self.up)
+
+ def setOrientation(self, direction=None, up=None):
+ """Set the rotation of the point of view.
+
+ :param direction: Sight direction vector or
+ None to keep the current one.
+ :type direction: numpy.ndarray-like of 3 float32 or None.
+ :param up: Vector pointing upward in the image plane or
+ None to keep the current one.
+ :type up: numpy.ndarray-like of 3 float32 or None.
+ :raises RuntimeError: if the direction and up are parallel.
+ """
+ if direction is None: # Use current direction
+ direction = self.direction
+ else:
+ assert len(direction) == 3
+ direction = numpy.array(direction, copy=True, dtype=numpy.float32)
+ direction /= numpy.linalg.norm(direction)
+
+ if up is None: # Use current up
+ up = self.up
+ else:
+ assert len(up) == 3
+ up = numpy.array(up, copy=True, dtype=numpy.float32)
+
+ # Update side and up to make sure they are perpendicular and normalized
+ side = numpy.cross(direction, up)
+ sidenormal = numpy.linalg.norm(side)
+ if sidenormal == 0.:
+ raise RuntimeError('direction and up vectors are parallel.')
+ # Alternative: when one of the input parameter is None, it is
+ # possible to guess correct vectors using previous direction and up
+ side /= sidenormal
+ up = numpy.cross(side, direction)
+ up /= numpy.linalg.norm(up)
+
+ self._side = side
+ self._up = up
+ self._direction = direction
+ self.notify()
+
+ @property
+ def position(self):
+ """Coordinates of the point of view as a numpy.ndarray of 3 float32."""
+ return self._position.copy()
+
+ @position.setter
+ def position(self, position):
+ assert len(position) == 3
+ self._position = numpy.array(position, copy=True, dtype=numpy.float32)
+ self.notify()
+
+ @property
+ def direction(self):
+ """Sight direction (ndarray of 3 float32)."""
+ return self._direction.copy()
+
+ @direction.setter
+ def direction(self, direction):
+ self.setOrientation(direction=direction)
+
+ @property
+ def up(self):
+ """Vector pointing upward in the image plane (ndarray of 3 float32).
+ """
+ return self._up.copy()
+
+ @up.setter
+ def up(self, up):
+ self.setOrientation(up=up)
+
+ @property
+ def side(self):
+ """Vector pointing towards the side of the image plane.
+
+ ndarray of 3 float32"""
+ return self._side.copy()
+
+ def move(self, direction, step=1.):
+ """Move the camera relative to the image plane.
+
+ :param str direction: Direction relative to image plane.
+ One of: 'up', 'down', 'left', 'right',
+ 'forward', 'backward'.
+ :param float step: The step of the pan to perform in the coordinate
+ in which the camera position is defined.
+ """
+ if direction in ('up', 'down'):
+ vector = self.up * (1. if direction == 'up' else -1.)
+ elif direction in ('left', 'right'):
+ vector = self.side * (1. if direction == 'right' else -1.)
+ elif direction in ('forward', 'backward'):
+ vector = self.direction * (1. if direction == 'forward' else -1.)
+ else:
+ raise ValueError('Unsupported direction: %s' % direction)
+
+ self.position += step * vector
+
+ def rotate(self, direction, angle=1.):
+ """First-person rotation of the camera towards the direction.
+
+ :param str direction: Direction of movement relative to image plane.
+ In: 'up', 'down', 'left', 'right'.
+ :param float angle: The angle in degrees of the rotation.
+ """
+ if direction in ('up', 'down'):
+ axis = self.side * (1. if direction == 'up' else -1.)
+ elif direction in ('left', 'right'):
+ axis = self.up * (1. if direction == 'left' else -1.)
+ else:
+ raise ValueError('Unsupported direction: %s' % direction)
+
+ matrix = transform.mat4RotateFromAngleAxis(numpy.radians(angle), *axis)
+ newdir = numpy.dot(matrix[:3, :3], self.direction)
+
+ if direction in ('up', 'down'):
+ # Rotate up to avoid up and new direction to be (almost) co-linear
+ newup = numpy.dot(matrix[:3, :3], self.up)
+ self.setOrientation(newdir, newup)
+ else:
+ # No need to rotate up here as it is the rotation axis
+ self.direction = newdir
+
+ def orbit(self, direction, center=(0., 0., 0.), angle=1.):
+ """Rotate the camera around a point.
+
+ :param str direction: Direction of movement relative to image plane.
+ In: 'up', 'down', 'left', 'right'.
+ :param center: Position around which to rotate the point of view.
+ :type center: numpy.ndarray-like of 3 float32.
+ :param float angle: he angle in degrees of the rotation.
+ """
+ if direction in ('up', 'down'):
+ axis = self.side * (1. if direction == 'down' else -1.)
+ elif direction in ('left', 'right'):
+ axis = self.up * (1. if direction == 'right' else -1.)
+ else:
+ raise ValueError('Unsupported direction: %s' % direction)
+
+ # Rotate viewing direction
+ rotmatrix = transform.mat4RotateFromAngleAxis(
+ numpy.radians(angle), *axis)
+ self.direction = numpy.dot(rotmatrix[:3, :3], self.direction)
+
+ # Rotate position around center
+ center = numpy.array(center, copy=False, dtype=numpy.float32)
+ matrix = numpy.dot(transform.mat4Translate(*center), rotmatrix)
+ matrix = numpy.dot(matrix, transform.mat4Translate(*(-center)))
+ position = numpy.append(self.position, 1.)
+ self.position = numpy.dot(matrix, position)[:3]
+
+ _RESET_CAMERA_ORIENTATIONS = {
+ 'side': ((-1., -1., -1.), (0., 1., 0.)),
+ 'front': ((0., 0., -1.), (0., 1., 0.)),
+ 'back': ((0., 0., 1.), (0., 1., 0.)),
+ 'top': ((0., -1., 0.), (0., 0., -1.)),
+ 'bottom': ((0., 1., 0.), (0., 0., 1.)),
+ 'right': ((-1., 0., 0.), (0., 1., 0.)),
+ 'left': ((1., 0., 0.), (0., 1., 0.))
+ }
+
+ def reset(self, face=None):
+ """Reset the camera position to pre-defined orientations.
+
+ :param str face: The direction of the camera in:
+ side, front, back, top, bottom, right, left.
+ """
+ if face not in self._RESET_CAMERA_ORIENTATIONS:
+ raise ValueError('Unsupported face: %s' % face)
+
+ distance = numpy.linalg.norm(self.position)
+ direction, up = self._RESET_CAMERA_ORIENTATIONS[face]
+ self.setOrientation(direction, up)
+ self.position = - self.direction * distance
+
+
+class Camera(transform.Transform):
+ """Combination of camera projection and position.
+
+ See :class:`Perspective` and :class:`CameraExtrinsic`.
+
+ :param float fovy: Vertical field-of-view in degrees.
+ :param float near: The near clipping plane Z coord (strictly positive).
+ :param float far: The far clipping plane Z coord (> near).
+ :param size: Viewport's size used to compute the aspect ratio.
+ :type size: 2-tuple of float (width, height).
+ :param position: Coordinates of the point of view.
+ :type position: numpy.ndarray-like of 3 float32.
+ :param direction: Sight direction vector.
+ :type direction: numpy.ndarray-like of 3 float32.
+ :param up: Vector pointing upward in the image plane.
+ :type up: numpy.ndarray-like of 3 float32.
+ """
+
+ def __init__(self, fovy=30., near=0.1, far=1., size=(1., 1.),
+ position=(0., 0., 0.),
+ direction=(0., 0., -1.), up=(0., 1., 0.)):
+ super(Camera, self).__init__()
+ self._intrinsic = transform.Perspective(fovy, near, far, size)
+ self._intrinsic.addListener(self._transformChanged)
+ self._extrinsic = CameraExtrinsic(position, direction, up)
+ self._extrinsic.addListener(self._transformChanged)
+
+ def _makeMatrix(self):
+ return numpy.dot(self.intrinsic.matrix, self.extrinsic.matrix)
+
+ def _transformChanged(self, source):
+ """Listener of intrinsic and extrinsic camera parameters instances."""
+ if source is not self:
+ self.notify()
+
+ def resetCamera(self, bounds):
+ """Change camera to have the bounds in the viewing frustum.
+
+ It updates the camera position and depth extent.
+ Camera sight direction and up are not affected.
+
+ :param bounds: The axes-aligned bounds to include.
+ :type bounds: numpy.ndarray: ((xMin, yMin, zMin), (xMax, yMax, zMax))
+ """
+
+ center = 0.5 * (bounds[0] + bounds[1])
+ radius = numpy.linalg.norm(0.5 * (bounds[1] - bounds[0]))
+
+ if isinstance(self.intrinsic, transform.Perspective):
+ # Get the viewpoint distance from the bounds center
+ minfov = numpy.radians(self.intrinsic.fovy)
+ width, height = self.intrinsic.size
+ if width < height:
+ minfov *= width / height
+
+ offset = radius / numpy.sin(0.5 * minfov)
+
+ # Update camera
+ self.extrinsic.position = \
+ center - offset * self.extrinsic.direction
+ self.intrinsic.setDepthExtent(offset - radius, offset + radius)
+
+ elif isinstance(self.intrinsic, transform.Orthographic):
+ # Y goes up
+ self.intrinsic.setClipping(
+ left=center[0] - radius,
+ right=center[0] + radius,
+ bottom=center[1] - radius,
+ top=center[1] + radius)
+
+ # Update camera
+ self.extrinsic.position = 0, 0, 0
+ self.intrinsic.setDepthExtent(center[2] - radius,
+ center[2] + radius)
+ else:
+ raise RuntimeError('Unsupported camera: %s' % self.intrinsic)
+
+ @property
+ def intrinsic(self):
+ """Intrinsic camera parameters, i.e., projection matrix."""
+ return self._intrinsic
+
+ @intrinsic.setter
+ def intrinsic(self, intrinsic):
+ self._intrinsic.removeListener(self._transformChanged)
+ self._intrinsic = intrinsic
+ self._intrinsic.addListener(self._transformChanged)
+
+ @property
+ def extrinsic(self):
+ """Extrinsic camera parameters, i.e., position and orientation."""
+ return self._extrinsic
+
+ def move(self, *args, **kwargs):
+ """See :meth:`CameraExtrinsic.move`."""
+ self.extrinsic.move(*args, **kwargs)
+
+ def rotate(self, *args, **kwargs):
+ """See :meth:`CameraExtrinsic.rotate`."""
+ self.extrinsic.rotate(*args, **kwargs)
+
+ def orbit(self, *args, **kwargs):
+ """See :meth:`CameraExtrinsic.orbit`."""
+ self.extrinsic.orbit(*args, **kwargs)
diff --git a/silx/gui/plot3d/scene/core.py b/silx/gui/plot3d/scene/core.py
new file mode 100644
index 0000000..a293f28
--- /dev/null
+++ b/silx/gui/plot3d/scene/core.py
@@ -0,0 +1,334 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the base scene structure.
+
+This module provides the classes for describing a tree structure with
+rendering and picking API.
+All nodes inherit from :class:`Base`.
+Nodes with children are provided with :class:`PrivateGroup` and
+:class:`Group` classes.
+Leaf rendering nodes should inherit from :class:`Elem`.
+"""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import itertools
+import weakref
+
+import numpy
+
+from . import event
+from . import transform
+
+from .viewport import Viewport
+
+
+# Nodes #######################################################################
+
+class Base(event.Notifier):
+ """A scene node with common features."""
+
+ def __init__(self):
+ super(Base, self).__init__()
+ self._visible = True
+ self._pickable = False
+
+ self._parentRef = None
+
+ self._transforms = transform.TransformList()
+ self._transforms.addListener(self._transformChanged)
+
+ # notifying properties
+
+ visible = event.notifyProperty('_visible',
+ doc="Visibility flag of the node")
+ pickable = event.notifyProperty('_pickable',
+ doc="True to make node pickable")
+
+ # Access to tree path
+
+ @property
+ def parent(self):
+ """Parent or None if no parent"""
+ return None if self._parentRef is None else self._parentRef()
+
+ def _setParent(self, parent):
+ """Set the parent of this node.
+
+ For internal use.
+
+ :param Base parent: The parent.
+ """
+ if parent is not None and self._parentRef is not None:
+ raise RuntimeError('Trying to add a node at two places.')
+ # Alternative: remove it from previous children list
+ self._parentRef = None if parent is None else weakref.ref(parent)
+
+ @property
+ def path(self):
+ """Tuple of scene nodes, from the tip of the tree down to this node.
+
+ If this tree is attached to a :class:`Viewport`,
+ then the :class:`Viewport` is the first element of path.
+ """
+ if self.parent is None:
+ return self,
+ elif isinstance(self.parent, Viewport):
+ return self.parent, self
+ else:
+ return self.parent.path + (self, )
+
+ @property
+ def viewport(self):
+ """The viewport this node is attached to or None."""
+ root = self.path[0]
+ return root if isinstance(root, Viewport) else None
+
+ @property
+ def objectToNDCTransform(self):
+ """Transform from object to normalized device coordinates.
+
+ Do not forget perspective divide.
+ """
+ # Using the Viewport's transforms property to proxy the camera
+ path = self.path
+ assert isinstance(path[0], Viewport)
+ return transform.StaticTransformList(elem.transforms for elem in path)
+
+ @property
+ def objectToSceneTransform(self):
+ """Transform from object to scene.
+
+ Combine transforms up to the Viewport (not including it).
+ """
+ path = self.path
+ if isinstance(path[0], Viewport):
+ path = path[1:] # Remove viewport to remove camera transforms
+ return transform.StaticTransformList(elem.transforms for elem in path)
+
+ # transform
+
+ @property
+ def transforms(self):
+ """List of transforms defining the frame of this node relative
+ to its parent."""
+ return self._transforms
+
+ @transforms.setter
+ def transforms(self, iterable):
+ self._transforms.removeListener(self._transformChanged)
+ if isinstance(iterable, transform.TransformList):
+ # If it is a TransformList, do not create one to enable sharing.
+ self._transforms = iterable
+ else:
+ assert hasattr(iterable, '__iter__')
+ self._transforms = transform.TransformList(iterable)
+ self._transforms.addListener(self._transformChanged)
+
+ def _transformChanged(self, source):
+ self.notify() # Broadcast transform notification
+
+ # Bounds
+
+ _CUBE_CORNERS = numpy.array(list(itertools.product((0., 1.), repeat=3)),
+ dtype=numpy.float32)
+ """Unit cube corners used to transform bounds"""
+
+ def _bounds(self, dataBounds=False):
+ """Override in subclass to provide bounds in object coordinates"""
+ return None
+
+ def bounds(self, transformed=False, dataBounds=False):
+ """Returns the bounds of this node aligned with the axis,
+ with or without transform applied.
+
+ :param bool transformed: False to give bounds in object coordinates
+ (the default), True to apply this object's
+ transforms.
+ :param bool dataBounds: False to give bounds of vertices (the default),
+ True to give bounds of the represented data.
+ :return: The bounds: ((xMin, yMin, zMin), (xMax, yMax, zMax)) or None
+ if no bounds.
+ :rtype: numpy.ndarray of float
+ """
+ bounds = self._bounds(dataBounds)
+
+ if transformed and bounds is not None:
+ bounds = self.transforms.transformBounds(bounds)
+
+ return bounds
+
+ # Rendering
+
+ def prepareGL2(self, ctx):
+ """Called before the rendering to prepare OpenGL resources.
+
+ Override in subclass.
+ """
+ pass
+
+ def renderGL2(self, ctx):
+ """Called to perform the OpenGL rendering.
+
+ Override in subclass.
+ """
+ pass
+
+ def render(self, ctx):
+ """Called internally to perform rendering."""
+ if self.visible:
+ ctx.pushTransform(self.transforms)
+ self.prepareGL2(ctx)
+ self.renderGL2(ctx)
+ ctx.popTransform()
+
+ def postRender(self, ctx):
+ """Hook called when parent's node render is finished.
+
+ Called in the reverse of rendering order (i.e., last child first).
+
+ Meant for nodes that modify the :class:`RenderContext` ctx to
+ reset their modifications.
+ """
+ pass
+
+ def pick(self, ctx, x, y, depth=None):
+ """True/False picking, should be fast"""
+ if self.pickable:
+ pass
+
+ def pickRay(self, ctx, ray):
+ """Picking returning list of ray intersections."""
+ if self.pickable:
+ pass
+
+
+class Elem(Base):
+ """A scene node that does some rendering."""
+
+ def __init__(self):
+ super(Elem, self).__init__()
+ # self.showBBox = False # Here or outside scene graph?
+ # self.clipPlane = None # This needs to be handled in the shader
+
+
+class PrivateGroup(Base):
+ """A scene node that renders its (private) childern.
+
+ :param iterable children: :class:`Base` nodes to add as children
+ """
+
+ class ChildrenList(event.NotifierList):
+ """List of children with notification and children's parent update."""
+
+ def _listWillChangeHook(self, methodName, *args, **kwargs):
+ super(PrivateGroup.ChildrenList, self)._listWillChangeHook(
+ methodName, *args, **kwargs)
+ for item in self:
+ item._setParent(None)
+
+ def _listWasChangedHook(self, methodName, *args, **kwargs):
+ for item in self:
+ item._setParent(self._parentRef())
+ super(PrivateGroup.ChildrenList, self)._listWasChangedHook(
+ methodName, *args, **kwargs)
+
+ def __init__(self, parent, children):
+ self._parentRef = weakref.ref(parent)
+ super(PrivateGroup.ChildrenList, self).__init__(children)
+
+ def __init__(self, children=()):
+ super(PrivateGroup, self).__init__()
+ self.__children = PrivateGroup.ChildrenList(self, children)
+ self.__children.addListener(self._updated)
+
+ @property
+ def _children(self):
+ """List of children to be rendered.
+
+ This private attribute is meant to be used by subclass.
+ """
+ return self.__children
+
+ @_children.setter
+ def _children(self, iterable):
+ self.__children.removeListener(self._updated)
+ for item in self.__children:
+ item._setParent(None)
+ del self.__children # This is needed
+ self.__children = PrivateGroup.ChildrenList(self, iterable)
+ self.__children.addListener(self._updated)
+ self.notify()
+
+ def _updated(self, source, *args, **kwargs):
+ """Listen for updates"""
+ if source is not self: # Avoid infinite recursion
+ self.notify(*args, **kwargs)
+
+ def _bounds(self, dataBounds=False):
+ """Compute the bounds from transformed children bounds"""
+ bounds = []
+ for child in self._children:
+ if child.visible:
+ childBounds = child.bounds(
+ transformed=True, dataBounds=dataBounds)
+ if childBounds is not None:
+ bounds.append(childBounds)
+
+ if len(bounds) == 0:
+ return None
+ else:
+ bounds = numpy.array(bounds, dtype=numpy.float32)
+ return numpy.array((bounds[:, 0, :].min(axis=0),
+ bounds[:, 1, :].max(axis=0)),
+ dtype=numpy.float32)
+
+ def prepareGL2(self, ctx):
+ pass
+
+ def renderGL2(self, ctx):
+ """Render all children"""
+ for child in self._children:
+ child.render(ctx)
+ for child in reversed(self._children):
+ child.postRender(ctx)
+
+
+class Group(PrivateGroup):
+ """A scene node that renders its (public) children."""
+
+ @property
+ def children(self):
+ """List of children to be rendered."""
+ return self._children
+
+ @children.setter
+ def children(self, iterable):
+ self._children = iterable
diff --git a/silx/gui/plot3d/scene/cutplane.py b/silx/gui/plot3d/scene/cutplane.py
new file mode 100644
index 0000000..79b4168
--- /dev/null
+++ b/silx/gui/plot3d/scene/cutplane.py
@@ -0,0 +1,374 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A cut plane in a 3D texture: hackish implementation...
+"""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/10/2016"
+
+import string
+import numpy
+
+from ... import _glutils
+from ..._glutils import gl
+
+from .function import Colormap
+from .primitives import Box, Geometry, PlaneInGroup
+from . import transform, utils
+
+
+class ColormapMesh3D(Geometry):
+ """A 3D mesh with color from a 3D texture."""
+
+ _shaders = ("""
+ attribute vec3 position;
+ attribute vec3 normal;
+
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+ //uniform mat3 matrixInvTranspose;
+ uniform vec3 dataScale;
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec3 vTexCoords;
+
+ void main(void)
+ {
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ //vNormal = matrixInvTranspose * normalize(normal);
+ vPosition = position;
+ vTexCoords = dataScale * position;
+ vNormal = normal;
+ gl_Position = matrix * vec4(position, 1.0);
+ }
+ """,
+ string.Template("""
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec3 vTexCoords;
+ uniform sampler3D data;
+ uniform float alpha;
+
+ $colormapDecl
+
+ $clippingDecl
+ $lightingFunction
+
+ void main(void)
+ {
+ float value = texture3D(data, vTexCoords).r;
+ vec4 color = $colormapCall(value);
+ color.a = alpha;
+
+ $clippingCall(vCameraPosition);
+
+ gl_FragColor = $lightingCall(color, vPosition, vNormal);
+ }
+ """))
+
+ def __init__(self, position, normal, data, copy=True,
+ mode='triangles', indices=None, colormap=None):
+ assert mode in self._TRIANGLE_MODES
+ data = numpy.array(data, copy=copy, order='C')
+ assert data.ndim == 3
+ self._data = data
+ self._texture = None
+ self._update_texture = True
+ self._update_texture_filter = False
+ self._alpha = 1.
+ self._colormap = colormap or Colormap() # Default colormap
+ self._colormap.addListener(self._cmapChanged)
+ self._interpolation = 'linear'
+ super(ColormapMesh3D, self).__init__(mode,
+ indices,
+ position=position,
+ normal=normal)
+
+ self.isBackfaceVisible = True
+
+ def setData(self, data, copy=True):
+ data = numpy.array(data, copy=copy, order='C')
+ assert data.ndim == 3
+ self._data = data
+ self._update_texture = True
+
+ def getData(self, copy=True):
+ return numpy.array(self._data, copy=copy)
+
+ @property
+ def interpolation(self):
+ """The texture interpolation mode: 'linear' or 'nearest'"""
+ return self._interpolation
+
+ @interpolation.setter
+ def interpolation(self, interpolation):
+ assert interpolation in ('linear', 'nearest')
+ self._interpolation = interpolation
+ self._update_texture_filter = True
+ self.notify()
+
+ @property
+ def alpha(self):
+ """Transparency of the plane, float in [0, 1]"""
+ return self._alpha
+
+ @alpha.setter
+ def alpha(self, alpha):
+ self._alpha = float(alpha)
+
+ @property
+ def colormap(self):
+ """The colormap used by this primitive"""
+ return self._colormap
+
+ def _cmapChanged(self, source, *args, **kwargs):
+ """Broadcast colormap changes"""
+ self.notify(*args, **kwargs)
+
+ def prepareGL2(self, ctx):
+ if self._texture is None or self._update_texture:
+ if self._texture is not None:
+ self._texture.discard()
+
+ if self.interpolation == 'nearest':
+ filter_ = gl.GL_NEAREST
+ else:
+ filter_ = gl.GL_LINEAR
+ self._update_texture = False
+ self._update_texture_filter = False
+ self._texture = _glutils.Texture(
+ gl.GL_R32F, self._data, gl.GL_RED,
+ minFilter=filter_,
+ magFilter=filter_,
+ wrap=gl.GL_CLAMP_TO_EDGE)
+
+ if self._update_texture_filter:
+ self._update_texture_filter = False
+ if self.interpolation == 'nearest':
+ filter_ = gl.GL_NEAREST
+ else:
+ filter_ = gl.GL_LINEAR
+ self._texture.minFilter = filter_
+ self._texture.magFilter = filter_
+
+ super(ColormapMesh3D, self).prepareGL2(ctx)
+
+ def renderGL2(self, ctx):
+ fragment = self._shaders[1].substitute(
+ clippingDecl=ctx.clipper.fragDecl,
+ clippingCall=ctx.clipper.fragCall,
+ lightingFunction=ctx.viewport.light.fragmentDef,
+ lightingCall=ctx.viewport.light.fragmentCall,
+ colormapDecl=self.colormap.decl,
+ colormapCall=self.colormap.call
+ )
+ program = ctx.glCtx.prog(self._shaders[0], fragment)
+ program.use()
+
+ ctx.viewport.light.setupProgram(ctx, program)
+ self.colormap.setupProgram(ctx, program)
+
+ if not self.isBackfaceVisible:
+ gl.glCullFace(gl.GL_BACK)
+ gl.glEnable(gl.GL_CULL_FACE)
+
+ program.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
+ program.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+ gl.glUniform1f(program.uniforms['alpha'], self._alpha)
+
+ shape = self._data.shape
+ scales = 1./shape[2], 1./shape[1], 1./shape[0]
+ gl.glUniform3f(program.uniforms['dataScale'], *scales)
+
+ gl.glUniform1i(program.uniforms['data'], self._texture.texUnit)
+
+ ctx.clipper.setupProgram(ctx, program)
+
+ self._texture.bind()
+ self._draw(program)
+
+ if not self.isBackfaceVisible:
+ gl.glDisable(gl.GL_CULL_FACE)
+
+
+class CutPlane(PlaneInGroup):
+ """A cutting plane in a 3D texture"""
+
+ def __init__(self, point=(0., 0., 0.), normal=(0., 0., 1.)):
+ self._data = None
+ self._mesh = None
+ self._alpha = 1.
+ self._interpolation = 'linear'
+ self._colormap = Colormap()
+ super(CutPlane, self).__init__(point, normal)
+
+ def setData(self, data, copy=True):
+ if data is None:
+ self._data = None
+ if self._mesh is not None:
+ self._children.remove(self._mesh)
+ self._mesh = None
+
+ else:
+ data = numpy.array(data, copy=copy, order='C')
+ assert data.ndim == 3
+ self._data = data
+ if self._mesh is not None:
+ self._mesh.setData(data, copy=False)
+
+ def getData(self, copy=True):
+ return None if self._mesh is None else self._mesh.getData(copy=copy)
+
+ @property
+ def alpha(self):
+ return self._alpha
+
+ @alpha.setter
+ def alpha(self, alpha):
+ self._alpha = float(alpha)
+ if self._mesh is not None:
+ self._mesh.alpha = alpha
+
+ @property
+ def colormap(self):
+ return self._colormap
+
+ @property
+ def interpolation(self):
+ """The texture interpolation mode: 'linear' (default) or 'nearest'"""
+ return self._interpolation
+
+ @interpolation.setter
+ def interpolation(self, interpolation):
+ assert interpolation in ('nearest', 'linear')
+ if interpolation != self.interpolation:
+ self._interpolation = interpolation
+ if self._mesh is not None:
+ self._mesh.interpolation = interpolation
+
+ def prepareGL2(self, ctx):
+ if self.isValid:
+
+ contourVertices = self.contourVertices
+
+ if (self.interpolation == 'nearest' and
+ contourVertices is not None and len(contourVertices)):
+ # Avoid cut plane co-linear with array bin edges
+ for index, normal in enumerate(((1., 0., 0.), (0., 1., 0.), (0., 0., 1.))):
+ if (numpy.all(numpy.equal(self.plane.normal, normal)) and
+ int(self.plane.point[index]) == self.plane.point[index]):
+ contourVertices += self.plane.normal * 0.01 # Add an offset
+ break
+
+ if self._mesh is None and self._data is not None:
+ self._mesh = ColormapMesh3D(contourVertices,
+ normal=self.plane.normal,
+ data=self._data,
+ copy=False,
+ mode='fan',
+ colormap=self.colormap)
+ self._mesh.alpha = self._alpha
+ self._interpolation = self.interpolation
+ self._children.insert(0, self._mesh)
+
+ if self._mesh is not None:
+ if (contourVertices is None or
+ len(contourVertices) == 0):
+ self._mesh.visible = False
+ else:
+ self._mesh.visible = True
+ self._mesh.setAttribute('normal', self.plane.normal)
+ self._mesh.setAttribute('position', contourVertices)
+
+ super(CutPlane, self).prepareGL2(ctx)
+
+ def renderGL2(self, ctx):
+ with self.viewport.light.turnOff():
+ super(CutPlane, self).renderGL2(ctx)
+
+ def _bounds(self, dataBounds=False):
+ if not dataBounds:
+ vertices = self.contourVertices
+ if vertices is not None:
+ return numpy.array(
+ (vertices.min(axis=0), vertices.max(axis=0)),
+ dtype=numpy.float32)
+ else:
+ return None # Plane in not slicing the data volume
+ else:
+ if self._data is None:
+ return None
+ else:
+ depth, height, width = self._data.shape
+ return numpy.array(((0., 0., 0.),
+ (width, height, depth)),
+ dtype=numpy.float32)
+
+ @property
+ def contourVertices(self):
+ """The vertices of the contour of the plane/bounds intersection."""
+ # TODO copy from PlaneInGroup, refactor all that!
+ bounds = self.bounds(dataBounds=True)
+ if bounds is None:
+ return None # No bounds: no vertices
+
+ # Check if cache is valid and return it
+ cachebounds, cachevertices = self._cache
+ if numpy.all(numpy.equal(bounds, cachebounds)):
+ return cachevertices
+
+ # Cache is not OK, rebuild it
+ boxvertices = bounds[0] + Box._vertices.copy()*(bounds[1] - bounds[0])
+ lineindices = Box._lineIndices
+ vertices = utils.boxPlaneIntersect(
+ boxvertices, lineindices, self.plane.normal, self.plane.point)
+
+ self._cache = bounds, vertices if len(vertices) != 0 else None
+
+ return self._cache[1]
+
+ # Render transforms RW, TODO refactor this!
+ @property
+ def transforms(self):
+ return self._transforms
+
+ @transforms.setter
+ def transforms(self, iterable):
+ self._transforms.removeListener(self._transformChanged)
+ if isinstance(iterable, transform.TransformList):
+ # If it is a TransformList, do not create one to enable sharing.
+ self._transforms = iterable
+ else:
+ assert hasattr(iterable, '__iter__')
+ self._transforms = transform.TransformList(iterable)
+ self._transforms.addListener(self._transformChanged)
diff --git a/silx/gui/plot3d/scene/event.py b/silx/gui/plot3d/scene/event.py
new file mode 100644
index 0000000..7b85434
--- /dev/null
+++ b/silx/gui/plot3d/scene/event.py
@@ -0,0 +1,225 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a simple generic notification system."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import logging
+
+from silx.utils.weakref import WeakList
+
+_logger = logging.getLogger(__name__)
+
+
+# Notifier ####################################################################
+
+class Notifier(object):
+ """Base class for object with notification mechanism."""
+
+ def __init__(self):
+ self._listeners = WeakList()
+
+ def addListener(self, listener):
+ """Register a listener.
+
+ Adding an already registered listener has no effect.
+
+ :param callable listener: The function or method to register.
+ """
+ if listener not in self._listeners:
+ self._listeners.append(listener)
+ else:
+ _logger.warning('Ignoring addition of an already registered listener')
+
+ def removeListener(self, listener):
+ """Remove a previously registered listener.
+
+ :param callable listener: The function or method to unregister.
+ """
+ try:
+ self._listeners.remove(listener)
+ except ValueError:
+ _logger.warn('Trying to remove a listener that is not registered')
+
+ def notify(self, *args, **kwargs):
+ """Notify all registered listeners with the given parameters.
+
+ Listeners are called directly in this method.
+ Listeners are called in the order they were registered.
+ """
+ for listener in self._listeners:
+ listener(self, *args, **kwargs)
+
+
+def notifyProperty(attrName, copy=False, converter=None, doc=None):
+ """Create a property that adds notification to an attribute.
+
+ :param str attrName: The name of the attribute to wrap.
+ :param bool copy: Whether to return a copy of the attribute
+ or not (the default).
+ :param converter: Function converting input value to appropriate type
+ This function takes a single argument and return the
+ converted value.
+ It can be used to perform some asserts.
+ :param str doc: The docstring of the property
+ :return: A property with getter and setter
+ """
+ if copy:
+ def getter(self):
+ return getattr(self, attrName).copy()
+ else:
+ def getter(self):
+ return getattr(self, attrName)
+
+ if converter is None:
+ def setter(self, value):
+ if getattr(self, attrName) != value:
+ setattr(self, attrName, value)
+ self.notify()
+
+ else:
+ def setter(self, value):
+ value = converter(value)
+ if getattr(self, attrName) != value:
+ setattr(self, attrName, value)
+ self.notify()
+
+ return property(getter, setter, doc=doc)
+
+
+class HookList(list):
+ """List with hooks before and after modification."""
+
+ def __init__(self, iterable):
+ super(HookList, self).__init__(iterable)
+
+ self._listWasChangedHook('__init__', iterable)
+
+ def _listWillChangeHook(self, methodName, *args, **kwargs):
+ """To override. Called before modifying the list.
+
+ This method is called with the name of the method called to
+ modify the list and its parameters.
+ """
+ pass
+
+ def _listWasChangedHook(self, methodName, *args, **kwargs):
+ """To override. Called after modifying the list.
+
+ This method is called with the name of the method called to
+ modify the list and its parameters.
+ """
+ pass
+
+ # Wrapping methods that modify the list
+
+ def _wrapper(self, methodName, *args, **kwargs):
+ """Generic wrapper of list methods calling the hooks."""
+ self._listWillChangeHook(methodName, *args, **kwargs)
+ result = getattr(super(HookList, self),
+ methodName)(*args, **kwargs)
+ self._listWasChangedHook(methodName, *args, **kwargs)
+ return result
+
+ # Add methods
+
+ def __iadd__(self, *args, **kwargs):
+ return self._wrapper('__iadd__', *args, **kwargs)
+
+ def __imul__(self, *args, **kwargs):
+ return self._wrapper('__imul__', *args, **kwargs)
+
+ def append(self, *args, **kwargs):
+ return self._wrapper('append', *args, **kwargs)
+
+ def extend(self, *args, **kwargs):
+ return self._wrapper('extend', *args, **kwargs)
+
+ def insert(self, *args, **kwargs):
+ return self._wrapper('insert', *args, **kwargs)
+
+ # Remove methods
+
+ def __delitem__(self, *args, **kwargs):
+ return self._wrapper('__delitem__', *args, **kwargs)
+
+ def __delslice__(self, *args, **kwargs):
+ return self._wrapper('__delslice__', *args, **kwargs)
+
+ def remove(self, *args, **kwargs):
+ return self._wrapper('remove', *args, **kwargs)
+
+ def pop(self, *args, **kwargs):
+ return self._wrapper('pop', *args, **kwargs)
+
+ # Set methods
+
+ def __setitem__(self, *args, **kwargs):
+ return self._wrapper('__setitem__', *args, **kwargs)
+
+ def __setslice__(self, *args, **kwargs):
+ return self._wrapper('__setslice__', *args, **kwargs)
+
+ # In place methods
+
+ def sort(self, *args, **kwargs):
+ return self._wrapper('sort', *args, **kwargs)
+
+ def reverse(self, *args, **kwargs):
+ return self._wrapper('reverse', *args, **kwargs)
+
+
+class NotifierList(HookList, Notifier):
+ """List of Notifiers with notification mechanism.
+
+ This class registers itself as a listener of the list items.
+
+ The default listener method forward notification from list items
+ to the listeners of the list.
+ """
+
+ def __init__(self, iterable=()):
+ Notifier.__init__(self)
+ HookList.__init__(self, iterable)
+
+ def _listWillChangeHook(self, methodName, *args, **kwargs):
+ for item in self:
+ item.removeListener(self._notified)
+
+ def _listWasChangedHook(self, methodName, *args, **kwargs):
+ for item in self:
+ item.addListener(self._notified)
+ self.notify()
+
+ def _notified(self, source, *args, **kwargs):
+ """Default listener forwarding list item changes to its listeners."""
+ # Avoid infinite recursion if the list is listening itself
+ if source is not self:
+ self.notify(*args, **kwargs)
diff --git a/silx/gui/plot3d/scene/function.py b/silx/gui/plot3d/scene/function.py
new file mode 100644
index 0000000..80ac820
--- /dev/null
+++ b/silx/gui/plot3d/scene/function.py
@@ -0,0 +1,471 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides functions to add to shaders."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/11/2016"
+
+
+import contextlib
+import logging
+import numpy
+
+from ..._glutils import gl
+
+from . import event
+from . import utils
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ProgramFunction(object):
+ """Class providing a function to add to a GLProgram shaders.
+ """
+
+ def setupProgram(self, context, program):
+ """Sets-up uniforms of a program using this shader function.
+
+ :param RenderContext context: The current rendering context
+ :param GLProgram program: The program to set-up.
+ It MUST be in use and using this function.
+ """
+ pass
+
+
+class ClippingPlane(ProgramFunction):
+ """Description of a clipping plane and rendering.
+
+ Convention: Clipping is performed in camera/eye space.
+
+ :param point: Local coordinates of a point on the plane.
+ :type point: numpy.ndarray-like of 3 float32
+ :param normal: Local coordinates of the plane normal.
+ :type normal: numpy.ndarray-like of 3 float32
+ """
+
+ _fragDecl = """
+ /* Clipping plane */
+ /* as rx + gy + bz + a > 0, clipping all positive */
+ uniform vec4 planeEq;
+
+ /* Position is in camera/eye coordinates */
+
+ bool isClipped(vec4 position) {
+ vec4 tmp = planeEq * position;
+ float value = tmp.x + tmp.y + tmp.z + planeEq.a;
+ return (value < 0.0001);
+ }
+
+ void clipping(vec4 position) {
+ if (isClipped(position)) {
+ discard;
+ }
+ }
+ /* End of clipping */
+ """
+
+ _fragDeclNoop = """
+ bool isClipped(vec4 position)
+ {
+ return false;
+ }
+
+ void clipping(vec4 position) {}
+ """
+
+ def __init__(self, point=(0., 0., 0.), normal=(0., 0., 0.)):
+ self._plane = utils.Plane(point, normal)
+
+ @property
+ def plane(self):
+ """Plane parameters in camera space."""
+ return self._plane
+
+ # GL2
+
+ @property
+ def fragDecl(self):
+ return self._fragDecl if self.plane.isPlane else self._fragDeclNoop
+
+ @property
+ def fragCall(self):
+ return "clipping"
+
+ def setupProgram(self, context, program):
+ """Sets-up uniforms of a program using this shader function.
+
+ :param RenderContext context: The current rendering context
+ :param GLProgram program: The program to set-up.
+ It MUST be in use and using this function.
+ """
+ if self.plane.isPlane:
+ gl.glUniform4f(program.uniforms['planeEq'], *self.plane.parameters)
+
+
+class DirectionalLight(event.Notifier, ProgramFunction):
+ """Description of a directional Phong light.
+
+ :param direction: The direction of the light or None to disable light
+ :type direction: ndarray of 3 floats or None
+ :param ambient: RGB ambient light
+ :type ambient: ndarray of 3 floats in [0, 1], default: (1., 1., 1.)
+ :param diffuse: RGB diffuse light parameter
+ :type diffuse: ndarray of 3 floats in [0, 1], default: (0., 0., 0.)
+ :param specular: RGB specular light parameter
+ :type specular: ndarray of 3 floats in [0, 1], default: (1., 1., 1.)
+ :param int shininess: The shininess of the material for specular term,
+ default: 0 which disables specular component.
+ """
+
+ fragmentShaderFunction = """
+ /* Lighting */
+ struct DLight {
+ vec3 lightDir; // Direction of light in object space
+ vec3 ambient;
+ vec3 diffuse;
+ vec3 specular;
+ float shininess;
+ vec3 viewPos; // Camera position in object space
+ };
+
+ uniform DLight dLight;
+
+ vec4 lighting(vec4 color, vec3 position, vec3 normal)
+ {
+ normal = normalize(normal);
+ // 1-sided
+ float nDotL = max(0.0, dot(normal, - dLight.lightDir));
+
+ // 2-sided
+ //float nDotL = dot(normal, - dLight.lightDir);
+ //if (nDotL < 0.) {
+ // nDotL = - nDotL;
+ // normal = - normal;
+ //}
+
+ float specFactor = 0.;
+ if (dLight.shininess > 0. && nDotL > 0.) {
+ vec3 reflection = reflect(dLight.lightDir, normal);
+ vec3 viewDir = normalize(dLight.viewPos - position);
+ specFactor = max(0.0, dot(reflection, viewDir));
+ if (specFactor > 0.) {
+ specFactor = pow(specFactor, dLight.shininess);
+ }
+ }
+
+ vec3 enlightedColor = color.rgb * (dLight.ambient +
+ dLight.diffuse * nDotL) +
+ dLight.specular * specFactor;
+
+ return vec4(enlightedColor.rgb, color.a);
+ }
+ /* End of Lighting */
+ """
+
+ fragmentShaderFunctionNoop = """
+ vec4 lighting(vec4 color, vec3 position, vec3 normal)
+ {
+ return color;
+ }
+ """
+
+ def __init__(self, direction=None,
+ ambient=(1., 1., 1.), diffuse=(0., 0., 0.),
+ specular=(1., 1., 1.), shininess=0):
+ super(DirectionalLight, self).__init__()
+ self._direction = None
+ self.direction = direction # Set _direction
+ self._isOn = True
+ self._ambient = ambient
+ self._diffuse = diffuse
+ self._specular = specular
+ self._shininess = shininess
+
+ ambient = event.notifyProperty('_ambient')
+ diffuse = event.notifyProperty('_diffuse')
+ specular = event.notifyProperty('_specular')
+ shininess = event.notifyProperty('_shininess')
+
+ @property
+ def isOn(self):
+ """True if light is on, False otherwise."""
+ return self._isOn and self._direction is not None
+
+ @isOn.setter
+ def isOn(self, isOn):
+ self._isOn = bool(isOn)
+
+ @contextlib.contextmanager
+ def turnOff(self):
+ """Context manager to temporary turn off lighting during rendering.
+
+ >>> with light.turnOff():
+ ... # Do some rendering without lighting
+ """
+ wason = self._isOn
+ self._isOn = False
+ yield
+ self._isOn = wason
+
+ @property
+ def direction(self):
+ """The direction of the light, or None if light is not on."""
+ return self._direction
+
+ @direction.setter
+ def direction(self, direction):
+ if direction is None:
+ self._direction = None
+ else:
+ assert len(direction) == 3
+ direction = numpy.array(direction, dtype=numpy.float32, copy=True)
+ norm = numpy.linalg.norm(direction)
+ assert norm != 0
+ self._direction = direction / norm
+ self.notify()
+
+ # GL2
+
+ @property
+ def fragmentDef(self):
+ """Definition to add to fragment shader"""
+ if self.isOn:
+ return self.fragmentShaderFunction
+ else:
+ return self.fragmentShaderFunctionNoop
+
+ @property
+ def fragmentCall(self):
+ """Function name to call in fragment shader"""
+ return "lighting"
+
+ def setupProgram(self, context, program):
+ """Sets-up uniforms of a program using this shader function.
+
+ :param RenderContext context: The current rendering context
+ :param GLProgram program: The program to set-up.
+ It MUST be in use and using this function.
+ """
+ if self.isOn and self._direction is not None:
+ # Transform light direction from camera space to object coords
+ lightdir = context.objectToCamera.transformDir(
+ self._direction, direct=False)
+ lightdir /= numpy.linalg.norm(lightdir)
+
+ gl.glUniform3f(program.uniforms['dLight.lightDir'], *lightdir)
+
+ # Convert view position to object coords
+ viewpos = context.objectToCamera.transformPoint(
+ numpy.array((0., 0., 0., 1.), dtype=numpy.float32),
+ direct=False,
+ perspectiveDivide=True)[:3]
+ gl.glUniform3f(program.uniforms['dLight.viewPos'], *viewpos)
+
+ gl.glUniform3f(program.uniforms['dLight.ambient'], *self.ambient)
+ gl.glUniform3f(program.uniforms['dLight.diffuse'], *self.diffuse)
+ gl.glUniform3f(program.uniforms['dLight.specular'], *self.specular)
+ gl.glUniform1f(program.uniforms['dLight.shininess'],
+ self.shininess)
+
+
+class Colormap(event.Notifier, ProgramFunction):
+ # TODO use colors for out-of-bound values, for <=0 with log, for nan
+ # TODO texture-based colormap
+
+ decl = """
+ #define CMAP_GRAY 0
+ #define CMAP_R_GRAY 1
+ #define CMAP_RED 2
+ #define CMAP_GREEN 3
+ #define CMAP_BLUE 4
+ #define CMAP_TEMP 5
+
+ uniform struct {
+ int id;
+ bool isLog;
+ float min;
+ float oneOverRange;
+ } cmap;
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ vec4 colormap(float value) {
+ if (cmap.isLog) { /* Log10 mapping */
+ if (value > 0.0) {
+ value = clamp(cmap.oneOverRange *
+ (oneOverLog10 * log(value) - cmap.min),
+ 0.0, 1.0);
+ } else {
+ value = 0.0;
+ }
+ } else { /* Linear mapping */
+ value = clamp(cmap.oneOverRange * (value - cmap.min), 0.0, 1.0);
+ }
+
+ if (cmap.id == CMAP_GRAY) {
+ return vec4(value, value, value, 1.0);
+ }
+ else if (cmap.id == CMAP_R_GRAY) {
+ float invValue = 1.0 - value;
+ return vec4(invValue, invValue, invValue, 1.0);
+ }
+ else if (cmap.id == CMAP_RED) {
+ return vec4(value, 0.0, 0.0, 1.0);
+ }
+ else if (cmap.id == CMAP_GREEN) {
+ return vec4(0.0, value, 0.0, 1.0);
+ }
+ else if (cmap.id == CMAP_BLUE) {
+ return vec4(0.0, 0.0, value, 1.0);
+ }
+ else if (cmap.id == CMAP_TEMP) {
+ //red: 0.5->0.75: 0->1
+ //green: 0.->0.25: 0->1; 0.75->1.: 1->0
+ //blue: 0.25->0.5: 1->0
+ return vec4(
+ clamp(4.0 * value - 2.0, 0.0, 1.0),
+ 1.0 - clamp(4.0 * abs(value - 0.5) - 1.0, 0.0, 1.0),
+ 1.0 - clamp(4.0 * value - 1.0, 0.0, 1.0),
+ 1.0);
+ }
+ else {
+ /* Unknown colormap */
+ return vec4(0.0, 0.0, 0.0, 1.0);
+ }
+ }
+ """
+
+ call = "colormap"
+
+ _COLORMAPS = {
+ 'gray': 0,
+ 'reversed gray': 1,
+ 'red': 2,
+ 'green': 3,
+ 'blue': 4,
+ 'temperature': 5
+ }
+
+ COLORMAPS = tuple(_COLORMAPS.keys())
+ """Tuple of supported colormap names."""
+
+ NORMS = 'linear', 'log'
+ """Tuple of supported normalizations."""
+
+ def __init__(self, name='gray', norm='linear', range_=(1., 10.)):
+ """Shader function to apply a colormap to a value.
+
+ :param str name: Name of the colormap.
+ :param str norm: Normalization to apply: 'linear' (default) or 'log'.
+ :param range_: Range of value to map to the colormap.
+ :type range_: 2-tuple of float (begin, end).
+ """
+ super(Colormap, self).__init__()
+
+ # Init privates to default
+ self._name, self._norm, self._range = 'gray', 'linear', (1., 10.)
+
+ # Set to param values through properties to go through asserts
+ self.name = name
+ self.norm = norm
+ self.range_ = range_
+
+ @property
+ def name(self):
+ """Name of the colormap in use."""
+ return self._name
+
+ @name.setter
+ def name(self, name):
+ if name != self._name:
+ assert name in self.COLORMAPS
+ self._name = name
+ self.notify()
+
+ @property
+ def norm(self):
+ """Normalization to use for colormap mapping.
+
+ Either 'linear' (the default) or 'log' for log10 mapping.
+ With 'log' normalization, values <= 0. are set to 1. (i.e. log == 0)
+ """
+ return self._norm
+
+ @norm.setter
+ def norm(self, norm):
+ if norm != self._norm:
+ assert norm in self.NORMS
+ self._norm = norm
+ if norm == 'log':
+ self.range_ = self.range_ # To test for positive range_
+ self.notify()
+
+ @property
+ def range_(self):
+ """Range of values to map to the colormap.
+
+ 2-tuple of floats: (begin, end).
+ The begin value is mapped to the origin of the colormap and the
+ end value is mapped to the other end of the colormap.
+ The colormap is reversed if begin > end.
+ """
+ return self._range
+
+ @range_.setter
+ def range_(self, range_):
+ assert len(range_) == 2
+ range_ = float(range_[0]), float(range_[1])
+
+ if self.norm == 'log' and (range_[0] <= 0. or range_[1] <= 0.):
+ _logger.warn(
+ "Log normalization and negative range: updating range.")
+ minPos = numpy.finfo(numpy.float32).tiny
+ range_ = max(range_[0], minPos), max(range_[1], minPos)
+
+ if range_ != self._range:
+ self._range = range_
+ self.notify()
+
+ def setupProgram(self, context, program):
+ """Sets-up uniforms of a program using this shader function.
+
+ :param RenderContext context: The current rendering context
+ :param GLProgram program: The program to set-up.
+ It MUST be in use and using this function.
+ """
+ gl.glUniform1i(program.uniforms['cmap.id'], self._COLORMAPS[self.name])
+ gl.glUniform1i(program.uniforms['cmap.isLog'], self._norm == 'log')
+
+ min_, max_ = self.range_
+ if self._norm == 'log':
+ min_, max_ = numpy.log10(min_), numpy.log10(max_)
+
+ gl.glUniform1f(program.uniforms['cmap.min'], min_)
+ gl.glUniform1f(program.uniforms['cmap.oneOverRange'],
+ (1. / (max_ - min_)) if max_ != min_ else 0.)
diff --git a/silx/gui/plot3d/scene/interaction.py b/silx/gui/plot3d/scene/interaction.py
new file mode 100644
index 0000000..68bfc13
--- /dev/null
+++ b/silx/gui/plot3d/scene/interaction.py
@@ -0,0 +1,652 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides interaction to plug on the scene graph."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+import logging
+import numpy
+
+from silx.gui.plot.Interaction import \
+ StateMachine, State, LEFT_BTN, RIGHT_BTN # , MIDDLE_BTN
+
+from . import transform
+
+
+_logger = logging.getLogger(__name__)
+
+
+# ClickOrDrag #################################################################
+
+# TODO merge with silx.gui.plot.Interaction.ClickOrDrag
+class ClickOrDrag(StateMachine):
+ """Click or drag interaction for a given button."""
+
+ DRAG_THRESHOLD_SQUARE_DIST = 5 ** 2
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == self.machine.button:
+ self.goto('clickOrDrag', x, y)
+ return True
+
+ class ClickOrDrag(State):
+ def enterState(self, x, y):
+ self.initPos = x, y
+
+ enter = enterState # silx v.0.3 support, remove when 0.4 out
+
+ def onMove(self, x, y):
+ dx = (x - self.initPos[0]) ** 2
+ dy = (y - self.initPos[1]) ** 2
+ if (dx ** 2 + dy ** 2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST:
+ self.goto('drag', self.initPos, (x, y))
+
+ def onRelease(self, x, y, btn):
+ if btn == self.machine.button:
+ self.machine.click(x, y)
+ self.goto('idle')
+
+ class Drag(State):
+ def enterState(self, initPos, curPos):
+ self.initPos = initPos
+ self.machine.beginDrag(*initPos)
+ self.machine.drag(*curPos)
+
+ enter = enterState # silx v.0.3 support, remove when 0.4 out
+
+ def onMove(self, x, y):
+ self.machine.drag(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == self.machine.button:
+ self.machine.endDrag(self.initPos, (x, y))
+ self.goto('idle')
+
+ def __init__(self, button=LEFT_BTN):
+ self.button = button
+ states = {
+ 'idle': ClickOrDrag.Idle,
+ 'clickOrDrag': ClickOrDrag.ClickOrDrag,
+ 'drag': ClickOrDrag.Drag
+ }
+ super(ClickOrDrag, self).__init__(states, 'idle')
+
+ def click(self, x, y):
+ """Called upon a left or right button click.
+ To override in a subclass.
+ """
+ pass
+
+ def beginDrag(self, x, y):
+ """Called at the beginning of a drag gesture with left button
+ pressed.
+ To override in a subclass.
+ """
+ pass
+
+ def drag(self, x, y):
+ """Called on mouse moved during a drag gesture.
+ To override in a subclass.
+ """
+ pass
+
+ def endDrag(self, x, y):
+ """Called at the end of a drag gesture when the left button is
+ released.
+ To override in a subclass.
+ """
+ pass
+
+
+# CameraRotate ################################################################
+
+class CameraRotate(ClickOrDrag):
+ """Camera rotation using an arcball-like interaction."""
+
+ def __init__(self, viewport, orbitAroundCenter=True, button=RIGHT_BTN):
+ self._viewport = viewport
+ self._orbitAroundCenter = orbitAroundCenter
+ self._reset()
+ super(CameraRotate, self).__init__(button)
+
+ def _reset(self):
+ self._origin, self._center = None, None
+ self._startExtrinsic = None
+
+ def click(self, x, y):
+ pass # No interaction yet
+
+ def beginDrag(self, x, y):
+ centerPos = None
+ if not self._orbitAroundCenter:
+ # Try to use picked object position as center of rotation
+ ndcZ = self._viewport._pickNdcZGL(x, y)
+ if ndcZ != 1.:
+ # Hit an object, use picked point as center
+ centerPos = self._viewport._getXZYGL(x, y) # Can return None
+
+ if centerPos is None:
+ # Not using picked position, use scene center
+ bounds = self._viewport.scene.bounds(transformed=True)
+ centerPos = 0.5 * (bounds[0] + bounds[1])
+
+ self._center = transform.Translate(*centerPos)
+ self._origin = x, y
+ self._startExtrinsic = self._viewport.camera.extrinsic.copy()
+
+ def drag(self, x, y):
+ if self._center is None:
+ return
+
+ dx, dy = self._origin[0] - x, self._origin[1] - y
+
+ if dx == 0 and dy == 0:
+ direction = self._startExtrinsic.direction
+ up = self._startExtrinsic.up
+ position = self._startExtrinsic.position
+ else:
+ minsize = min(self._viewport.size)
+ distance = numpy.sqrt(dx ** 2 + dy ** 2)
+ angle = distance / minsize * numpy.pi
+
+ # Take care of y inversion
+ direction = dx * self._startExtrinsic.side - \
+ dy * self._startExtrinsic.up
+ direction /= numpy.linalg.norm(direction)
+ axis = numpy.cross(direction, self._startExtrinsic.direction)
+ axis /= numpy.linalg.norm(axis)
+
+ # Orbit start camera with current angle and axis
+ # Rotate viewing direction
+ rotation = transform.Rotate(numpy.degrees(angle), *axis)
+ direction = rotation.transformDir(self._startExtrinsic.direction)
+ up = rotation.transformDir(self._startExtrinsic.up)
+
+ # Rotate position around center
+ trlist = transform.StaticTransformList((
+ self._center,
+ rotation,
+ self._center.inverse()))
+ position = trlist.transformPoint(self._startExtrinsic.position)
+
+ camerapos = self._viewport.camera.extrinsic
+ camerapos.setOrientation(direction, up)
+ camerapos.position = position
+
+ def endDrag(self, x, y):
+ self._reset()
+
+
+# CameraSelectPan #############################################################
+
+class CameraSelectPan(ClickOrDrag):
+ """Picking on click and pan camera on drag."""
+
+ def __init__(self, viewport, button=LEFT_BTN, selectCB=None):
+ self._viewport = viewport
+ self._selectCB = selectCB
+ self._lastPosNdc = None
+ super(CameraSelectPan, self).__init__(button)
+
+ def click(self, x, y):
+ if self._selectCB is not None:
+ ndcZ = self._viewport._pickNdcZGL(x, y)
+ position = self._viewport._getXZYGL(x, y)
+ # This assume no object lie on the far plane
+ # Alternative, change the depth range so that far is < 1
+ if ndcZ != 1. and position is not None:
+ self._selectCB((x, y, ndcZ), position)
+
+ def beginDrag(self, x, y):
+ ndc = self._viewport.windowToNdc(x, y)
+ ndcZ = self._viewport._pickNdcZGL(x, y)
+ # ndcZ is the panning plane
+ if ndc is not None and ndcZ is not None:
+ self._lastPosNdc = numpy.array((ndc[0], ndc[1], ndcZ, 1.),
+ dtype=numpy.float32)
+ else:
+ self._lastPosNdc = None
+
+ def drag(self, x, y):
+ if self._lastPosNdc is not None:
+ ndc = self._viewport.windowToNdc(x, y)
+ if ndc is not None:
+ ndcPos = numpy.array((ndc[0], ndc[1], self._lastPosNdc[2], 1.),
+ dtype=numpy.float32)
+
+ # Convert last and current NDC positions to scene coords
+ scenePos = self._viewport.camera.transformPoint(
+ ndcPos, direct=False, perspectiveDivide=True)
+ lastScenePos = self._viewport.camera.transformPoint(
+ self._lastPosNdc, direct=False, perspectiveDivide=True)
+
+ # Get translation in scene coords
+ translation = scenePos[:3] - lastScenePos[:3]
+ self._viewport.camera.extrinsic.position -= translation
+
+ # Store for next drag
+ self._lastPosNdc = ndcPos
+
+ def endDrag(self, x, y):
+ self._lastPosNdc = None
+
+
+# CameraWheel #################################################################
+
+class CameraWheel(object):
+ """StateMachine like class, just handling wheel events."""
+
+ # TODO choose scale of motion? Translation or Scale?
+ def __init__(self, viewport, mode='center', scaleTransform=None):
+ assert mode in ('center', 'position', 'scale')
+ self._viewport = viewport
+ if mode == 'center':
+ self._zoomTo = self._zoomToCenter
+ elif mode == 'position':
+ self._zoomTo = self._zoomToPosition
+ elif mode == 'scale':
+ self._zoomTo = self._zoomByScale
+ self._scale = scaleTransform
+ else:
+ raise ValueError('Unsupported mode: %s' % mode)
+
+ def handleEvent(self, eventName, *args, **kwargs):
+ if eventName == 'wheel':
+ return self._zoomTo(*args, **kwargs)
+
+ def _zoomToCenter(self, x, y, angleInDegrees):
+ """Zoom to center of display.
+
+ Only works with perspective camera.
+ """
+ direction = 'forward' if angleInDegrees > 0 else 'backward'
+ self._viewport.camera.move(direction)
+ return True
+
+ def _zoomToPositionAbsolute(self, x, y, angleInDegrees):
+ """Zoom while keeping pixel under mouse invariant.
+
+ Only works with perspective camera.
+ """
+ ndc = self._viewport.windowToNdc(x, y)
+ if ndc is not None:
+ near = numpy.array((ndc[0], ndc[1], -1., 1.), dtype=numpy.float32)
+
+ nearscene = self._viewport.camera.transformPoint(
+ near, direct=False, perspectiveDivide=True)
+
+ far = numpy.array((ndc[0], ndc[1], 1., 1.), dtype=numpy.float32)
+ farscene = self._viewport.camera.transformPoint(
+ far, direct=False, perspectiveDivide=True)
+
+ dirscene = farscene[:3] - nearscene[:3]
+ dirscene /= numpy.linalg.norm(dirscene)
+
+ if angleInDegrees < 0:
+ dirscene *= -1.
+
+ # TODO which scale
+ self._viewport.camera.extrinsic.position += dirscene
+ return True
+
+ def _zoomToPosition(self, x, y, angleInDegrees):
+ """Zoom while keeping pixel under mouse invariant."""
+ projection = self._viewport.camera.intrinsic
+ extrinsic = self._viewport.camera.extrinsic
+
+ if isinstance(projection, transform.Perspective):
+ # For perspective projection, move camera
+ ndc = self._viewport.windowToNdc(x, y)
+ if ndc is not None:
+ ndcz = self._viewport._pickNdcZGL(x, y)
+
+ position = numpy.array((ndc[0], ndc[1], ndcz),
+ dtype=numpy.float32)
+ positionscene = self._viewport.camera.transformPoint(
+ position, direct=False, perspectiveDivide=True)
+
+ camtopos = extrinsic.position - positionscene
+
+ step = 0.2 * (1. if angleInDegrees < 0 else -1.)
+ extrinsic.position += step * camtopos
+
+ elif isinstance(projection, transform.Orthographic):
+ # For orthographic projection, change projection borders
+ ndcx, ndcy = self._viewport.windowToNdc(x, y, checkInside=False)
+
+ step = 0.2 * (1. if angleInDegrees < 0 else -1.)
+
+ dx = (ndcx + 1) / 2.
+ stepwidth = step * (projection.right - projection.left)
+ left = projection.left - dx * stepwidth
+ right = projection.right + (1. - dx) * stepwidth
+
+ dy = (ndcy + 1) / 2.
+ stepheight = step * (projection.top - projection.bottom)
+ bottom = projection.bottom - dy * stepheight
+ top = projection.top + (1. - dy) * stepheight
+
+ projection.setClipping(left, right, bottom, top)
+
+ else:
+ raise RuntimeError('Unsupported camera', projection)
+ return True
+
+ def _zoomByScale(self, x, y, angleInDegrees):
+ """Zoom by scaling scene (do not keep pixel under mouse invariant)."""
+ scalefactor = 1.1
+ if angleInDegrees < 0.:
+ scalefactor = 1. / scalefactor
+ self._scale.scale = scalefactor * self._scale.scale
+
+ self._viewport.adjustCameraDepthExtent()
+ return True
+
+
+# FocusManager ################################################################
+
+class FocusManager(StateMachine):
+ """Manages focus across multiple event handlers
+
+ On press an event handler can acquire focus.
+ By default it looses focus when all buttons are released.
+ """
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ for eventHandler in self.machine.eventHandlers:
+ requestfocus = eventHandler.handleEvent('press', x, y, btn)
+ if requestfocus:
+ self.goto('focus', eventHandler, btn)
+ break
+
+ def _processEvent(self, *args):
+ for eventHandler in self.machine.eventHandlers:
+ consumeevent = eventHandler.handleEvent(*args)
+ if consumeevent:
+ break
+
+ def onMove(self, x, y):
+ self._processEvent('move', x, y)
+
+ def onRelease(self, x, y, btn):
+ self._processEvent('release', x, y, btn)
+
+ def onWheel(self, x, y, angle):
+ self._processEvent('wheel', x, y, angle)
+
+ class Focus(State):
+ def enterState(self, eventHandler, btn):
+ self.eventHandler = eventHandler
+ self.focusBtns = {btn} # Set
+
+ enter = enterState # silx v.0.3 support, remove when 0.4 out
+
+ def onPress(self, x, y, btn):
+ self.focusBtns.add(btn)
+ self.eventHandler.handleEvent('press', x, y, btn)
+
+ def onMove(self, x, y):
+ self.eventHandler.handleEvent('move', x, y)
+
+ def onRelease(self, x, y, btn):
+ self.focusBtns.discard(btn)
+ requestfocus = self.eventHandler.handleEvent('release', x, y, btn)
+ if len(self.focusBtns) == 0 and not requestfocus:
+ self.goto('idle')
+
+ def onWheel(self, x, y, angleInDegrees):
+ self.eventHandler.handleEvent('wheel', x, y, angleInDegrees)
+
+ def __init__(self, eventHandlers=()):
+ self.eventHandlers = list(eventHandlers)
+
+ states = {
+ 'idle': FocusManager.Idle,
+ 'focus': FocusManager.Focus
+ }
+ super(FocusManager, self).__init__(states, 'idle')
+
+ def cancel(self):
+ for handler in self.eventHandlers:
+ handler.cancel()
+
+
+# CameraControl ###############################################################
+
+class CameraControl(FocusManager):
+ """Combine wheel, selectPan and rotate state machine."""
+ def __init__(self, viewport,
+ orbitAroundCenter=False,
+ mode='center', scaleTransform=None,
+ selectCB=None):
+ handlers = (CameraWheel(viewport, mode, scaleTransform),
+ CameraSelectPan(viewport, LEFT_BTN, selectCB),
+ CameraRotate(viewport, orbitAroundCenter, RIGHT_BTN))
+ super(CameraControl, self).__init__(handlers)
+
+
+# PlaneRotate #################################################################
+
+class PlaneRotate(ClickOrDrag):
+ """Plane rotation using arcball interaction.
+
+ Arcball ref.:
+ Ken Shoemake. ARCBALL: A user interface for specifying three-dimensional
+ orientation using a mouse. In Proc. GI '92. (1992). pp. 151-156.
+ """
+
+ def __init__(self, viewport, plane, button=RIGHT_BTN):
+ self._viewport = viewport
+ self._plane = plane
+ self._reset()
+ super(PlaneRotate, self).__init__(button)
+
+ def _reset(self):
+ self._beginNormal, self._beginCenter = None, None
+
+ def click(self, x, y):
+ pass # No interaction
+
+ @staticmethod
+ def _sphereUnitVector(radius, center, position):
+ """Returns the unit vector of the projection of position on a sphere.
+
+ It assumes an orthographic projection.
+ For perspective projection, it gives an approximation, but it
+ simplifies computations and results in consistent arcball control
+ in control space.
+
+ All parameters must be in screen coordinate system
+ (either pixels or normalized coordinates).
+
+ :param float radius: The radius of the sphere.
+ :param center: (x, y) coordinates of the center.
+ :param position: (x, y) coordinates of the cursor position.
+ :return: Unit vector.
+ :rtype: numpy.ndarray of 3 floats.
+ """
+ center, position = numpy.array(center), numpy.array(position)
+
+ # Normalize x and y on a unit circle
+ spherecoords = (position - center) / float(radius)
+ squarelength = numpy.sum(spherecoords ** 2)
+
+ # Project on the unit sphere and compute z coordinates
+ if squarelength > 1.0: # Outside sphere: project
+ spherecoords /= numpy.sqrt(squarelength)
+ zsphere = 0.0
+ else: # In sphere: compute z
+ zsphere = numpy.sqrt(1. - squarelength)
+
+ spherecoords = numpy.append(spherecoords, zsphere)
+ return spherecoords
+
+ def beginDrag(self, x, y):
+ # Makes sure the point defining the plane is at the center as
+ # it will be the center of rotation (as rotation is applied to normal)
+ self._plane.plane.point = self._plane.center
+
+ # Store the plane normal
+ self._beginNormal = self._plane.plane.normal
+
+ _logger.debug(
+ 'Begin arcball, plane center %s', str(self._plane.center))
+
+ # Do the arcball on the screen
+ radius = min(self._viewport.size)
+ if self._plane.center is None:
+ self._beginCenter = None
+
+ else:
+ center = self._plane.objectToNDCTransform.transformPoint(
+ self._plane.center, perspectiveDivide=True)
+ self._beginCenter = self._viewport.ndcToWindow(
+ center[0], center[1], checkInside=False)
+
+ self._startVector = self._sphereUnitVector(
+ radius, self._beginCenter, (x, y))
+
+ def drag(self, x, y):
+ if self._beginCenter is None:
+ return
+
+ # Compute rotation: this is twice the rotation of the arcball
+ radius = min(self._viewport.size)
+ currentvector = self._sphereUnitVector(
+ radius, self._beginCenter, (x, y))
+ crossprod = numpy.cross(self._startVector, currentvector)
+ dotprod = numpy.dot(self._startVector, currentvector)
+
+ quaternion = numpy.append(crossprod, dotprod)
+ # Rotation was computed with Y downward, but apply in NDC, invert Y
+ quaternion[1] *= -1.
+
+ rotation = transform.Rotate()
+ rotation.quaternion = quaternion
+
+ # Convert to NDC, rotate, convert back to object
+ normal = self._plane.objectToNDCTransform.transformNormal(
+ self._beginNormal)
+ normal = rotation.transformNormal(normal)
+ normal = self._plane.objectToNDCTransform.transformNormal(
+ normal, direct=False)
+ self._plane.plane.normal = normal
+
+ def endDrag(self, x, y):
+ self._reset()
+
+
+# PlanePan ###################################################################
+
+class PlanePan(ClickOrDrag):
+ """Pan a plane along its normal on drag."""
+
+ def __init__(self, viewport, plane, button=LEFT_BTN):
+ self._plane = plane
+ self._viewport = viewport
+ self._beginPlanePoint = None
+ self._beginPos = None
+ self._dragNdcZ = 0.
+ super(PlanePan, self).__init__(button)
+
+ def click(self, x, y):
+ pass
+
+ def beginDrag(self, x, y):
+ ndc = self._viewport.windowToNdc(x, y)
+ ndcZ = self._viewport._pickNdcZGL(x, y)
+ # ndcZ is the panning plane
+ if ndc is not None and ndcZ is not None:
+ ndcPos = numpy.array((ndc[0], ndc[1], ndcZ, 1.),
+ dtype=numpy.float32)
+ scenePos = self._viewport.camera.transformPoint(
+ ndcPos, direct=False, perspectiveDivide=True)
+ self._beginPos = self._plane.objectToSceneTransform.transformPoint(
+ scenePos, direct=False)
+ self._dragNdcZ = ndcZ
+ else:
+ self._beginPos = None
+ self._dragNdcZ = 0.
+
+ self._beginPlanePoint = self._plane.plane.point
+
+ def drag(self, x, y):
+ if self._beginPos is not None:
+ ndc = self._viewport.windowToNdc(x, y)
+ if ndc is not None:
+ ndcPos = numpy.array((ndc[0], ndc[1], self._dragNdcZ, 1.),
+ dtype=numpy.float32)
+
+ # Convert last and current NDC positions to scene coords
+ scenePos = self._viewport.camera.transformPoint(
+ ndcPos, direct=False, perspectiveDivide=True)
+ curPos = self._plane.objectToSceneTransform.transformPoint(
+ scenePos, direct=False)
+
+ # Get translation in scene coords
+ translation = curPos[:3] - self._beginPos[:3]
+
+ newPoint = self._beginPlanePoint + translation
+
+ # Keep plane point in bounds
+ bounds = self._plane.parent.bounds(dataBounds=True)
+ if bounds is not None:
+ newPoint = numpy.clip(
+ newPoint, a_min=bounds[0], a_max=bounds[1])
+
+ # Only update plane if it is in some bounds
+ self._plane.plane.point = newPoint
+
+ def endDrag(self, x, y):
+ self._beginPlanePoint = None
+
+
+# PlaneControl ################################################################
+
+class PlaneControl(FocusManager):
+ """Combine wheel, selectPan and rotate state machine for plane control."""
+ def __init__(self, viewport, plane,
+ mode='center', scaleTransform=None):
+ handlers = (CameraWheel(viewport, mode, scaleTransform),
+ PlanePan(viewport, plane, LEFT_BTN),
+ PlaneRotate(viewport, plane, RIGHT_BTN))
+ super(PlaneControl, self).__init__(handlers)
+
+
+class PanPlaneRotateCameraControl(FocusManager):
+ """Combine wheel, pan plane and camera rotate state machine."""
+ def __init__(self, viewport, plane,
+ mode='center', scaleTransform=None):
+ handlers = (CameraWheel(viewport, mode, scaleTransform),
+ PlanePan(viewport, plane, LEFT_BTN),
+ CameraRotate(viewport,
+ orbitAroundCenter=False,
+ button=RIGHT_BTN))
+ super(PanPlaneRotateCameraControl, self).__init__(handlers)
diff --git a/silx/gui/plot3d/scene/primitives.py b/silx/gui/plot3d/scene/primitives.py
new file mode 100644
index 0000000..ca2616a
--- /dev/null
+++ b/silx/gui/plot3d/scene/primitives.py
@@ -0,0 +1,1764 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import collections
+import ctypes
+from functools import reduce
+import logging
+import string
+
+import numpy
+
+from silx.gui.plot.Colors import rgba
+
+from ... import _glutils
+from ..._glutils import gl
+
+from . import event
+from . import core
+from . import transform
+from . import utils
+
+_logger = logging.getLogger(__name__)
+
+
+# Geometry ####################################################################
+
+class Geometry(core.Elem):
+ """Set of vertices with normals and colors.
+
+ :param str mode: OpenGL drawing mode:
+ lines, line_strip, loop, triangles, triangle_strip, fan
+ :param indices: Array of vertex indices or None
+ :param bool copy: True (default) to copy the data, False to use as is.
+ :param attributes: Provide list of attributes as extra parameters.
+ """
+
+ _ATTR_INFO = {
+ 'position': {'dims': (1, 2), 'lastDim': (2, 3, 4)},
+ 'normal': {'dims': (1, 2), 'lastDim': (3,)},
+ 'color': {'dims': (1, 2), 'lastDim': (3, 4)},
+ }
+
+ _MODE_CHECKS = { # Min, Modulo
+ 'lines': (2, 2), 'line_strip': (2, 0), 'loop': (2, 0),
+ 'points': (1, 0),
+ 'triangles': (3, 3), 'triangle_strip': (3, 0), 'fan': (3, 0)
+ }
+
+ _MODES = {
+ 'lines': gl.GL_LINES,
+ 'line_strip': gl.GL_LINE_STRIP,
+ 'loop': gl.GL_LINE_LOOP,
+
+ 'points': gl.GL_POINTS,
+
+ 'triangles': gl.GL_TRIANGLES,
+ 'triangle_strip': gl.GL_TRIANGLE_STRIP,
+ 'fan': gl.GL_TRIANGLE_FAN
+ }
+
+ _LINE_MODES = 'lines', 'line_strip', 'loop'
+
+ _TRIANGLE_MODES = 'triangles', 'triangle_strip', 'fan'
+
+ def __init__(self, mode, indices=None, copy=True, **attributes):
+ super(Geometry, self).__init__()
+
+ self._vbos = {} # Store current vbos
+ self._unsyncAttributes = [] # Store attributes to copy to vbos
+ self.__bounds = None # Cache object's bounds
+
+ assert mode in self._MODES
+ self._mode = mode
+
+ # Set attributes
+ self._attributes = {}
+ for name, data in attributes.items():
+ self.setAttribute(name, data, copy=copy)
+
+ # Set indices
+ self._indices = None
+ self.setIndices(indices, copy=copy)
+
+ # More consistency checks
+ mincheck, modulocheck = self._MODE_CHECKS[self._mode]
+ if self._indices is not None:
+ nbvertices = len(self._indices)
+ else:
+ nbvertices = self.nbVertices
+ assert nbvertices >= mincheck
+ if modulocheck != 0:
+ assert (nbvertices % modulocheck) == 0
+
+ @staticmethod
+ def _glReadyArray(array, copy=True):
+ """Making a contiguous array, checking float types.
+
+ :param iterable array: array-like data to prepare for attribute
+ :param bool copy: True to make a copy of the array, False to use as is
+ """
+ # Convert single value (int, float, numpy types) to tuple
+ if not isinstance(array, collections.Iterable):
+ array = (array, )
+
+ # Makes sure it is an array
+ array = numpy.array(array, copy=False)
+
+ # Cast all float to float32
+ dtype = None
+ if numpy.dtype(array.dtype).kind == 'f':
+ dtype = numpy.float32
+
+ return numpy.array(array, dtype=dtype, order='C', copy=copy)
+
+ @property
+ def nbVertices(self):
+ """Returns the number of vertices of current attributes.
+
+ It returns None if there is no attributes.
+ """
+ for array in self._attributes.values():
+ if len(array.shape) == 2:
+ return len(array)
+ return None
+
+ def setAttribute(self, name, array, copy=True):
+ """Set attribute with provided array.
+
+ :param str name: The name of the attribute
+ :param array: Array-like attribute data or None to remove attribute
+ :param bool copy: True (default) to copy the data, False to use as is
+ """
+ # This triggers associated GL resources to be garbage collected
+ self._vbos.pop(name, None)
+
+ if array is None:
+ self._attributes.pop(name, None)
+
+ else:
+ array = self._glReadyArray(array, copy=copy)
+
+ if name not in self._ATTR_INFO:
+ _logger.info('Not checking attibute %s dimensions', name)
+ else:
+ checks = self._ATTR_INFO[name]
+
+ if (len(array.shape) == 1 and checks['lastDim'] == (1,) and
+ len(array) > 1):
+ array = array.reshape((len(array), 1))
+
+ # Checks
+ assert len(array.shape) in checks['dims'], "Attr %s" % name
+ assert array.shape[-1] in checks['lastDim'], "Attr %s" % name
+
+ # Check length against another attribute array
+ # Causes problems when updating
+ # nbVertices = self.nbVertices
+ # if len(array.shape) == 2 and nbVertices is not None:
+ # assert len(array) == nbVertices
+
+ self._attributes[name] = array
+ if len(array.shape) == 2: # Store this in a VBO
+ self._unsyncAttributes.append(name)
+
+ if name == 'position': # Reset bounds
+ self.__bounds = None
+
+ self.notify()
+
+ def getAttribute(self, name, copy=True):
+ """Returns the numpy.ndarray corresponding to the name attribute.
+
+ :param str name: The name of the attribute to get.
+ :param bool copy: True to get a copy (default),
+ False to get internal array (DO NOT MODIFY)
+ :return: The corresponding array or None if no corresponding attribute.
+ :rtype: numpy.ndarray
+ """
+ attr = self._attributes.get(name, None)
+ return None if attr is None else numpy.array(attr, copy=copy)
+
+ def useAttribute(self, program, name=None):
+ """Enable and bind attribute(s) for a specific program.
+
+ This MUST be called with OpenGL context active and after prepareGL2
+ has been called.
+
+ :param GLProgram program: The program for which to set the attributes
+ :param str name: The attribute name to set or None to set then all
+ """
+ if name is None:
+ for name in program.attributes:
+ self.useAttribute(program, name)
+
+ else:
+ attribute = program.attributes.get(name)
+ if attribute is None:
+ return
+
+ vboattrib = self._vbos.get(name)
+ if vboattrib is not None:
+ gl.glEnableVertexAttribArray(attribute)
+ vboattrib.setVertexAttrib(attribute)
+
+ elif name not in self._attributes:
+ gl.glDisableVertexAttribArray(attribute)
+
+ else:
+ array = self._attributes[name]
+ assert array is not None
+
+ if len(array.shape) == 1:
+ assert len(array) in (1, 2, 3, 4)
+ gl.glDisableVertexAttribArray(attribute)
+ _glVertexAttribFunc = getattr(
+ _glutils.gl, 'glVertexAttrib{}f'.format(len(array)))
+ _glVertexAttribFunc(attribute, *array)
+ else:
+ # TODO As is this is a never event, remove?
+ gl.glEnableVertexAttribArray(attribute)
+ gl.glVertexAttribPointer(
+ attribute,
+ array.shape[-1],
+ _glutils.numpyToGLType(array.dtype),
+ gl.GL_FALSE,
+ 0,
+ array)
+
+ def setIndices(self, indices, copy=True):
+ """Set the primitive indices to use.
+
+ :param indices: Array-like of uint primitive indices or None to unset
+ :param bool copy: True (default) to copy the data, False to use as is
+ """
+ # Trigger garbage collection of previous indices VBO if any
+ self._vbos.pop('__indices__', None)
+
+ if indices is None:
+ self._indices = None
+ else:
+ indices = self._glReadyArray(indices, copy=copy).ravel()
+ assert indices.dtype.name in ('uint8', 'uint16', 'uint32')
+ if _logger.getEffectiveLevel() <= logging.DEBUG:
+ # This might be a costy check
+ assert indices.max() < self.nbVertices
+ self._indices = indices
+
+ def getIndices(self, copy=True):
+ """Returns the numpy.ndarray corresponding to the indices.
+
+ :param bool copy: True to get a copy (default),
+ False to get internal array (DO NOT MODIFY)
+ :return: The primitive indices array or None if not set.
+ :rtype: numpy.ndarray or None
+ """
+ if self._indices is None:
+ return None
+ else:
+ return numpy.array(self._indices, copy=copy)
+
+ def _bounds(self, dataBounds=False):
+ if self.__bounds is None:
+ self.__bounds = numpy.zeros((2, 3), dtype=numpy.float32)
+ # Support vertex with to 2 to 4 coordinates
+ positions = self._attributes['position']
+ self.__bounds[0, :positions.shape[1]] = positions.min(axis=0)[:3]
+ self.__bounds[1, :positions.shape[1]] = positions.max(axis=0)[:3]
+ return self.__bounds.copy()
+
+ def prepareGL2(self, ctx):
+ # TODO manage _vbo and multiple GL context + allow to share them !
+ # TODO make one or multiple VBO depending on len(vertices),
+ # TODO use a general common VBO for small amount of data
+ for name in self._unsyncAttributes:
+ array = self._attributes[name]
+ self._vbos[name] = ctx.glCtx.makeVboAttrib(array)
+ self._unsyncAttributes = []
+
+ if self._indices is not None and '__indices__' not in self._vbos:
+ vbo = ctx.glCtx.makeVbo(self._indices,
+ usage=gl.GL_STATIC_DRAW,
+ target=gl.GL_ELEMENT_ARRAY_BUFFER)
+ self._vbos['__indices__'] = vbo
+
+ def _draw(self, program=None, nbVertices=None):
+ """Perform OpenGL draw calls.
+
+ :param GLProgram program:
+ If not None, call :meth:`useAttribute` for this program.
+ :param int nbVertices:
+ The number of vertices to render or None to render all vertices.
+ """
+ if program is not None:
+ self.useAttribute(program)
+
+ if self._indices is None:
+ if nbVertices is None:
+ nbVertices = self.nbVertices
+ gl.glDrawArrays(self._MODES[self._mode], 0, nbVertices)
+ else:
+ if nbVertices is None:
+ nbVertices = self._indices.size
+ with self._vbos['__indices__']:
+ gl.glDrawElements(self._MODES[self._mode],
+ nbVertices,
+ _glutils.numpyToGLType(self._indices.dtype),
+ ctypes.c_void_p(0))
+
+
+# Lines #######################################################################
+
+class Lines(Geometry):
+ """A set of segments"""
+ _shaders = ("""
+ attribute vec3 position;
+ attribute vec3 normal;
+ attribute vec4 color;
+
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec4 vColor;
+
+ void main(void)
+ {
+ gl_Position = matrix * vec4(position, 1.0);
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ vPosition = position;
+ vNormal = normal;
+ vColor = color;
+ }
+ """,
+ string.Template("""
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec4 vColor;
+
+ $clippingDecl
+ $lightingFunction
+
+ void main(void)
+ {
+ $clippingCall(vCameraPosition);
+ gl_FragColor = $lightingCall(vColor, vPosition, vNormal);
+ }
+ """))
+
+ def __init__(self, positions, normals=None, colors=(1., 1., 1., 1.),
+ indices=None, mode='lines', width=1.):
+ if mode == 'strip':
+ mode = 'line_strip'
+ assert mode in self._LINE_MODES
+
+ self._width = width
+ self._smooth = True
+
+ super(Lines, self).__init__(mode, indices,
+ position=positions,
+ normal=normals,
+ color=colors)
+
+ width = event.notifyProperty('_width', converter=float,
+ doc="Width of the line in pixels.")
+
+ smooth = event.notifyProperty(
+ '_smooth',
+ converter=bool,
+ doc="Smooth line rendering enabled (bool, default: True)")
+
+ def renderGL2(self, ctx):
+ # Prepare program
+ isnormals = 'normal' in self._attributes
+ if isnormals:
+ fraglightfunction = ctx.viewport.light.fragmentDef
+ else:
+ fraglightfunction = ctx.viewport.light.fragmentShaderFunctionNoop
+
+ fragment = self._shaders[1].substitute(
+ clippingDecl=ctx.clipper.fragDecl,
+ clippingCall=ctx.clipper.fragCall,
+ lightingFunction=fraglightfunction,
+ lightingCall=ctx.viewport.light.fragmentCall)
+ prog = ctx.glCtx.prog(self._shaders[0], fragment)
+ prog.use()
+
+ if isnormals:
+ ctx.viewport.light.setupProgram(ctx, prog)
+
+ gl.glLineWidth(self.width)
+
+ prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
+ prog.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+
+ ctx.clipper.setupProgram(ctx, prog)
+
+ with gl.enabled(gl.GL_LINE_SMOOTH, self._smooth):
+ self._draw(prog)
+
+
+class DashedLines(Lines):
+ """Set of dashed lines
+
+ This MUST be defined as a set of lines (no strip or loop).
+ """
+
+ _shaders = ("""
+ attribute vec3 position;
+ attribute vec3 origin;
+ attribute vec3 normal;
+ attribute vec4 color;
+
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+ uniform vec2 viewportSize; /* Width, height of the viewport */
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec4 vColor;
+ varying vec2 vOriginFragCoord;
+
+ void main(void)
+ {
+ gl_Position = matrix * vec4(position, 1.0);
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ vPosition = position;
+ vNormal = normal;
+ vColor = color;
+
+ vec4 clipOrigin = matrix * vec4(origin, 1.0);
+ vec4 ndcOrigin = clipOrigin / clipOrigin.w; /* Perspective divide */
+ /* Convert to same frame as gl_FragCoord: lower-left, pixel center at 0.5, 0.5 */
+ vOriginFragCoord = (ndcOrigin.xy + vec2(1.0, 1.0)) * 0.5 * viewportSize + vec2(0.5, 0.5);
+ }
+ """, # noqa
+ string.Template("""
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec4 vColor;
+ varying vec2 vOriginFragCoord;
+
+ uniform vec2 dash;
+
+ $clippingDecl
+ $lightingFunction
+
+ void main(void)
+ {
+ /* Discard off dash fragments */
+ float lineDist = distance(vOriginFragCoord, gl_FragCoord.xy);
+ if (mod(lineDist, dash.x + dash.y) > dash.x) {
+ discard;
+ }
+ $clippingCall(vCameraPosition);
+ gl_FragColor = $lightingCall(vColor, vPosition, vNormal);
+ }
+ """))
+
+ def __init__(self, positions, colors=(1., 1., 1., 1.),
+ indices=None, width=1.):
+ self._dash = 1, 0
+ super(DashedLines, self).__init__(positions=positions,
+ colors=colors,
+ indices=indices,
+ mode='lines',
+ width=width)
+
+ @property
+ def dash(self):
+ """Dash of the line as a 2-tuple of lengths in pixels: (on, off)"""
+ return self._dash
+
+ @dash.setter
+ def dash(self, dash):
+ dash = float(dash[0]), float(dash[1])
+ if dash != self._dash:
+ self._dash = dash
+ self.notify()
+
+ def getPositions(self, copy=True):
+ """Get coordinates of lines.
+
+ :param bool copy: True to get a copy, False otherwise
+ :returns: Coordinates of lines
+ :rtype: numpy.ndarray of float32 of shape (N, 2, Ndim)
+ """
+ return self.getAttribute('position', copy=copy)
+
+ def setPositions(self, positions, copy=True):
+ """Set line coordinates.
+
+ :param positions: Array of line coordinates
+ :param bool copy: True to copy input array, False to use as is
+ """
+ self.setAttribute('position', positions, copy=copy)
+ # Update line origins from given positions
+ origins = numpy.array(positions, copy=True, order='C')
+ origins[1::2] = origins[::2]
+ self.setAttribute('origin', origins, copy=False)
+
+ def renderGL2(self, context):
+ # Prepare program
+ isnormals = 'normal' in self._attributes
+ if isnormals:
+ fraglightfunction = context.viewport.light.fragmentDef
+ else:
+ fraglightfunction = \
+ context.viewport.light.fragmentShaderFunctionNoop
+
+ fragment = self._shaders[1].substitute(
+ clippingDecl=context.clipper.fragDecl,
+ clippingCall=context.clipper.fragCall,
+ lightingFunction=fraglightfunction,
+ lightingCall=context.viewport.light.fragmentCall)
+ program = context.glCtx.prog(self._shaders[0], fragment)
+ program.use()
+
+ if isnormals:
+ context.viewport.light.setupProgram(context, program)
+
+ gl.glLineWidth(self.width)
+
+ program.setUniformMatrix('matrix', context.objectToNDC.matrix)
+ program.setUniformMatrix('transformMat',
+ context.objectToCamera.matrix,
+ safe=True)
+
+ gl.glUniform2f(
+ program.uniforms['viewportSize'], *context.viewport.size)
+ gl.glUniform2f(program.uniforms['dash'], *self.dash)
+
+ context.clipper.setupProgram(context, program)
+
+ self._draw(program)
+
+
+class Box(core.PrivateGroup):
+ """Rectangular box"""
+
+ _lineIndices = numpy.array((
+ (0, 1), (1, 2), (2, 3), (3, 0), # Lines with z=0
+ (0, 4), (1, 5), (2, 6), (3, 7), # Lines from z=0 to z=1
+ (4, 5), (5, 6), (6, 7), (7, 4)), # Lines with z=1
+ dtype=numpy.uint8)
+
+ _faceIndices = numpy.array(
+ (0, 3, 1, 2, 5, 6, 4, 7, 7, 6, 6, 2, 7, 3, 4, 0, 5, 1),
+ dtype=numpy.uint8)
+
+ _vertices = numpy.array((
+ # Corners with z=0
+ (0., 0., 0.), (1., 0., 0.), (1., 1., 0.), (0., 1., 0.),
+ # Corners with z=1
+ (0., 0., 1.), (1., 0., 1.), (1., 1., 1.), (0., 1., 1.)),
+ dtype=numpy.float32)
+
+ def __init__(self, size=(1., 1., 1.),
+ stroke=(1., 1., 1., 1.),
+ fill=(1., 1., 1., 0.)):
+ super(Box, self).__init__()
+
+ self._fill = Mesh3D(self._vertices,
+ colors=rgba(fill),
+ mode='triangle_strip',
+ indices=self._faceIndices)
+ self._fill.visible = self.fillColor[-1] != 0.
+
+ self._stroke = Lines(self._vertices,
+ indices=self._lineIndices,
+ colors=rgba(stroke),
+ mode='lines')
+ self._stroke.visible = self.strokeColor[-1] != 0.
+ self.strokeWidth = 1.
+
+ self._children = [self._stroke, self._fill]
+
+ self._size = None
+ self.size = size
+
+ @property
+ def size(self):
+ """Size of the box (sx, sy, sz)"""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ assert len(size) == 3
+ size = tuple(size)
+ if size != self.size:
+ self._size = size
+ self._fill.setAttribute(
+ 'position',
+ self._vertices * numpy.array(size, dtype=numpy.float32))
+ self._stroke.setAttribute(
+ 'position',
+ self._vertices * numpy.array(size, dtype=numpy.float32))
+ self.notify()
+
+ @property
+ def strokeSmooth(self):
+ """True to draw smooth stroke, False otherwise"""
+ return self._stroke.smooth
+
+ @strokeSmooth.setter
+ def strokeSmooth(self, smooth):
+ smooth = bool(smooth)
+ if smooth != self._stroke.smooth:
+ self._stroke.smooth = smooth
+ self.notify()
+
+ @property
+ def strokeWidth(self):
+ """Width of the stroke (float)"""
+ return self._stroke.width
+
+ @strokeWidth.setter
+ def strokeWidth(self, width):
+ width = float(width)
+ if width != self.strokeWidth:
+ self._stroke.width = width
+ self.notify()
+
+ @property
+ def strokeColor(self):
+ """RGBA color of the box lines (4-tuple of float in [0, 1])"""
+ return tuple(self._stroke.getAttribute('color', copy=False))
+
+ @strokeColor.setter
+ def strokeColor(self, color):
+ color = rgba(color)
+ if color != self.strokeColor:
+ self._stroke.setAttribute('color', color)
+ # Fully transparent = hidden
+ self._stroke.visible = color[-1] != 0.
+ self.notify()
+
+ @property
+ def fillColor(self):
+ """RGBA color of the box faces (4-tuple of float in [0, 1])"""
+ return tuple(self._fill.getAttribute('color', copy=False))
+
+ @fillColor.setter
+ def fillColor(self, color):
+ color = rgba(color)
+ if color != self.fillColor:
+ self._fill.setAttribute('color', color)
+ # Fully transparent = hidden
+ self._fill.visible = color[-1] != 0.
+ self.notify()
+
+ @property
+ def fillCulling(self):
+ return self._fill.culling
+
+ @fillCulling.setter
+ def fillCulling(self, culling):
+ self._fill.culling = culling
+
+
+class Axes(Lines):
+ """3D RGB orthogonal axes"""
+ _vertices = numpy.array(((0., 0., 0.), (1., 0., 0.),
+ (0., 0., 0.), (0., 1., 0.),
+ (0., 0., 0.), (0., 0., 1.)),
+ dtype=numpy.float32)
+
+ _colors = numpy.array(((255, 0, 0, 255), (255, 0, 0, 255),
+ (0, 255, 0, 255), (0, 255, 0, 255),
+ (0, 0, 255, 255), (0, 0, 255, 255)),
+ dtype=numpy.uint8)
+
+ def __init__(self):
+ super(Axes, self).__init__(self._vertices,
+ colors=self._colors,
+ width=3.)
+
+
+class BoxWithAxes(Lines):
+ """Rectangular box with RGB OX, OY, OZ axes
+
+ :param color: RGBA color of the box
+ """
+
+ _vertices = numpy.array((
+ # Axes corners
+ (0., 0., 0.), (1., 0., 0.),
+ (0., 0., 0.), (0., 1., 0.),
+ (0., 0., 0.), (0., 0., 1.),
+ # Box corners with z=0
+ (1., 0., 0.), (1., 1., 0.), (0., 1., 0.),
+ # Box corners with z=1
+ (0., 0., 1.), (1., 0., 1.), (1., 1., 1.), (0., 1., 1.)),
+ dtype=numpy.float32)
+
+ _axesColors = numpy.array(((1., 0., 0., 1.), (1., 0., 0., 1.),
+ (0., 1., 0., 1.), (0., 1., 0., 1.),
+ (0., 0., 1., 1.), (0., 0., 1., 1.)),
+ dtype=numpy.float32)
+
+ _lineIndices = numpy.array((
+ (0, 1), (2, 3), (4, 5), # Axes lines
+ (6, 7), (7, 8), # Box lines with z=0
+ (6, 10), (7, 11), (8, 12), # Box lines from z=0 to z=1
+ (9, 10), (10, 11), (11, 12), (12, 9)), # Box lines with z=1
+ dtype=numpy.uint8)
+
+ def __init__(self, color=(1., 1., 1., 1.)):
+ self._color = (1., 1., 1., 1.)
+ colors = numpy.ones((len(self._vertices), 4), dtype=numpy.float32)
+ colors[:len(self._axesColors), :] = self._axesColors
+
+ super(BoxWithAxes, self).__init__(self._vertices,
+ indices=self._lineIndices,
+ colors=colors,
+ width=2.)
+ self.color = color
+
+ @property
+ def color(self):
+ """The RGBA color to use for the box: 4 float in [0, 1]"""
+ return self._color
+
+ @color.setter
+ def color(self, color):
+ color = rgba(color)
+ if color != self._color:
+ self._color = color
+ colors = numpy.empty((len(self._vertices), 4), dtype=numpy.float32)
+ colors[:len(self._axesColors), :] = self._axesColors
+ colors[len(self._axesColors):, :] = self._color
+ self.setAttribute('color', colors) # Do the notification
+
+
+class PlaneInGroup(core.PrivateGroup):
+ """A plane using its parent bounds to display a contour.
+
+ If plane is outside the bounds of its parent, it is not visible.
+
+ Cannot set the transform attribute of this primitive.
+ This primitive never has any bounds.
+ """
+ # TODO inherit from Lines directly?, make sure the plane remains visible?
+
+ def __init__(self, point=(0., 0., 0.), normal=(0., 0., 1.)):
+ super(PlaneInGroup, self).__init__()
+ self._cache = None, None # Store bounds, vertices
+ self._outline = None
+
+ self._color = None
+ self.color = 1., 1., 1., 1. # Set _color
+ self._width = 2.
+
+ self._plane = utils.Plane(point, normal)
+ self._plane.addListener(self._planeChanged)
+
+ def moveToCenter(self):
+ """Place the plane at the center of the data, not changing orientation.
+ """
+ if self.parent is not None:
+ bounds = self.parent.bounds(dataBounds=True)
+ if bounds is not None:
+ center = (bounds[0] + bounds[1]) / 2.
+ _logger.debug('Moving plane to center: %s', str(center))
+ self.plane.point = center
+
+ @property
+ def color(self):
+ """Plane outline color (array of 4 float in [0, 1])."""
+ return self._color.copy()
+
+ @color.setter
+ def color(self, color):
+ self._color = numpy.array(color, copy=True, dtype=numpy.float32)
+ if self._outline is not None:
+ self._outline.setAttribute('color', self._color)
+ self.notify() # This is OK as Lines are rebuild for each rendering
+
+ @property
+ def width(self):
+ """Width of the plane stroke in pixels"""
+ return self._width
+
+ @width.setter
+ def width(self, width):
+ self._width = float(width)
+ if self._outline is not None:
+ self._outline.width = self._width # Sync width
+
+ # Plane access
+
+ @property
+ def plane(self):
+ """The plane parameters in the frame of the object."""
+ return self._plane
+
+ def _planeChanged(self, source):
+ """Listener of plane changes: clear cache and notify listeners."""
+ self._cache = None, None
+ self.notify()
+
+ # Disable some scene features
+
+ @property
+ def transforms(self):
+ # Ready-only transforms to prevent using it
+ return self._transforms
+
+ def _bounds(self, dataBounds=False):
+ # This is bound less as it uses the bounds of its parent.
+ return None
+
+ @property
+ def contourVertices(self):
+ """The vertices of the contour of the plane/bounds intersection."""
+ parent = self.parent
+ if parent is None:
+ return None # No parent: no vertices
+
+ bounds = parent.bounds(dataBounds=True)
+ if bounds is None:
+ return None # No bounds: no vertices
+
+ # Check if cache is valid and return it
+ cachebounds, cachevertices = self._cache
+ if numpy.all(numpy.equal(bounds, cachebounds)):
+ return cachevertices
+
+ # Cache is not OK, rebuild it
+ boxvertices = bounds[0] + Box._vertices.copy()*(bounds[1] - bounds[0])
+ lineindices = Box._lineIndices
+ vertices = utils.boxPlaneIntersect(
+ boxvertices, lineindices, self.plane.normal, self.plane.point)
+
+ self._cache = bounds, vertices if len(vertices) != 0 else None
+
+ return self._cache[1]
+
+ @property
+ def center(self):
+ """The center of the plane/bounds intersection points."""
+ if not self.isValid:
+ return None
+ else:
+ return numpy.mean(self.contourVertices, axis=0)
+
+ @property
+ def isValid(self):
+ """True if a contour is defined, False otherwise."""
+ return self.plane.isPlane and self.contourVertices is not None
+
+ def prepareGL2(self, ctx):
+ if self.isValid:
+ if self._outline is None: # Init outline
+ self._outline = Lines(self.contourVertices,
+ mode='loop',
+ colors=self.color)
+ self._outline.width = self._width
+ self._children.append(self._outline)
+
+ # Update vertices, TODO only when necessary
+ self._outline.setAttribute('position', self.contourVertices)
+
+ super(PlaneInGroup, self).prepareGL2(ctx)
+
+ def renderGL2(self, ctx):
+ if self.isValid:
+ super(PlaneInGroup, self).renderGL2(ctx)
+
+
+# Points ######################################################################
+
+_POINTS_ATTR_INFO = Geometry._ATTR_INFO.copy()
+_POINTS_ATTR_INFO.update(value={'dims': (1, 2), 'lastDim': (1,)},
+ size={'dims': (1, 2), 'lastDim': (1,)},
+ symbol={'dims': (1, 2), 'lastDim': (1,)})
+
+
+class Points(Geometry):
+ """A set of data points with an associated value and size."""
+ _shaders = ("""
+ #version 120
+
+ attribute vec3 position;
+ attribute float symbol;
+ attribute float value;
+ attribute float size;
+
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+
+ uniform vec2 valRange;
+
+ varying vec4 vCameraPosition;
+ varying float vSymbol;
+ varying float vNormValue;
+ varying float vSize;
+
+ void main(void)
+ {
+ vSymbol = symbol;
+
+ vNormValue = clamp((value - valRange.x) / (valRange.y - valRange.x),
+ 0.0, 1.0);
+
+ bool isValueInRange = value >= valRange.x && value <= valRange.y;
+ if (isValueInRange) {
+ gl_Position = matrix * vec4(position, 1.0);
+ } else {
+ gl_Position = vec4(2.0, 0.0, 0.0, 1.0); /* Get clipped */
+ }
+ vCameraPosition = transformMat * vec4(position, 1.0);
+
+ gl_PointSize = size;
+ vSize = size;
+ }
+ """,
+ string.Template("""
+ #version 120
+
+ varying vec4 vCameraPosition;
+ varying float vSize;
+ varying float vSymbol;
+ varying float vNormValue;
+
+ $clippinDecl
+
+ /* Circle */
+ #define SYMBOL_CIRCLE 1.0
+
+ float alphaCircle(vec2 coord, float size) {
+ float radius = 0.5;
+ float r = distance(coord, vec2(0.5, 0.5));
+ return clamp(size * (radius - r), 0.0, 1.0);
+ }
+
+ /* Half lines */
+ #define SYMBOL_H_LINE 2.0
+ #define LEFT 1.0
+ #define RIGHT 2.0
+ #define SYMBOL_V_LINE 3.0
+ #define UP 1.0
+ #define DOWN 2.0
+
+ float alphaLine(vec2 coord, float size, float direction)
+ {
+ vec2 delta = abs(size * (coord - 0.5));
+
+ if (direction == SYMBOL_H_LINE) {
+ return (delta.y < 0.5) ? 1.0 : 0.0;
+ }
+ else if (direction == SYMBOL_H_LINE + LEFT) {
+ return (coord.x <= 0.5 && delta.y < 0.5) ? 1.0 : 0.0;
+ }
+ else if (direction == SYMBOL_H_LINE + RIGHT) {
+ return (coord.x >= 0.5 && delta.y < 0.5) ? 1.0 : 0.0;
+ }
+ else if (direction == SYMBOL_V_LINE) {
+ return (delta.x < 0.5) ? 1.0 : 0.0;
+ }
+ else if (direction == SYMBOL_V_LINE + UP) {
+ return (coord.y <= 0.5 && delta.x < 0.5) ? 1.0 : 0.0;
+ }
+ else if (direction == SYMBOL_V_LINE + DOWN) {
+ return (coord.y >= 0.5 && delta.x < 0.5) ? 1.0 : 0.0;
+ }
+ return 1.0;
+ }
+
+ void main(void)
+ {
+ $clippingCall(vCameraPosition);
+
+ gl_FragColor = vec4(0.5 * vNormValue + 0.5, 0.0, 0.0, 1.0);
+
+ float alpha = 1.0;
+ float symbol = floor(vSymbol);
+ if (1 == 1) { //symbol == SYMBOL_CIRCLE) {
+ alpha = alphaCircle(gl_PointCoord, vSize);
+ }
+ else if (symbol >= SYMBOL_H_LINE &&
+ symbol <= (SYMBOL_V_LINE + DOWN)) {
+ alpha = alphaLine(gl_PointCoord, vSize, symbol);
+ }
+ if (alpha == 0.0) {
+ discard;
+ }
+ gl_FragColor.a *= alpha;
+ }
+ """))
+
+ _ATTR_INFO = _POINTS_ATTR_INFO
+
+ # TODO Add colormap, light?
+
+ def __init__(self, vertices, values=0., sizes=1., indices=None,
+ symbols=0.,
+ minValue=None, maxValue=None):
+ super(Points, self).__init__('points', indices,
+ position=vertices,
+ value=values,
+ size=sizes,
+ symbol=symbols)
+
+ values = self._attributes['value']
+ self._minValue = values.min() if minValue is None else minValue
+ self._maxValue = values.max() if maxValue is None else maxValue
+
+ minValue = event.notifyProperty('_minValue')
+ maxValue = event.notifyProperty('_maxValue')
+
+ def renderGL2(self, ctx):
+ fragment = self._shaders[1].substitute(
+ clippingDecl=ctx.clipper.fragDecl,
+ clippingCall=ctx.clipper.fragCall)
+ prog = ctx.glCtx.prog(self._shaders[0], fragment)
+ prog.use()
+
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
+ prog.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+
+ ctx.clipper.setupProgram(ctx, prog)
+
+ gl.glUniform2f(prog.uniforms['valRange'], self.minValue, self.maxValue)
+
+ self._draw(prog)
+
+
+class ColorPoints(Geometry):
+ """A set of points with an associated color and size."""
+
+ _shaders = ("""
+ #version 120
+
+ attribute vec3 position;
+ attribute float symbol;
+ attribute vec4 color;
+ attribute float size;
+
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+
+ varying vec4 vCameraPosition;
+ varying float vSymbol;
+ varying vec4 vColor;
+ varying float vSize;
+
+ void main(void)
+ {
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ vSymbol = symbol;
+ vColor = color;
+ gl_Position = matrix * vec4(position, 1.0);
+ gl_PointSize = size;
+ vSize = size;
+ }
+ """,
+ string.Template("""
+ #version 120
+
+ varying vec4 vCameraPosition;
+ varying float vSize;
+ varying float vSymbol;
+ varying vec4 vColor;
+
+ $clippingDecl;
+
+ /* Circle */
+ #define SYMBOL_CIRCLE 1.0
+
+ float alphaCircle(vec2 coord, float size) {
+ float radius = 0.5;
+ float r = distance(coord, vec2(0.5, 0.5));
+ return clamp(size * (radius - r), 0.0, 1.0);
+ }
+
+ /* Half lines */
+ #define SYMBOL_H_LINE 2.0
+ #define LEFT 1.0
+ #define RIGHT 2.0
+ #define SYMBOL_V_LINE 3.0
+ #define UP 1.0
+ #define DOWN 2.0
+
+ float alphaLine(vec2 coord, float size, float direction)
+ {
+ vec2 delta = abs(size * (coord - 0.5));
+
+ if (direction == SYMBOL_H_LINE) {
+ return (delta.y < 0.5) ? 1.0 : 0.0;
+ }
+ else if (direction == SYMBOL_H_LINE + LEFT) {
+ return (coord.x <= 0.5 && delta.y < 0.5) ? 1.0 : 0.0;
+ }
+ else if (direction == SYMBOL_H_LINE + RIGHT) {
+ return (coord.x >= 0.5 && delta.y < 0.5) ? 1.0 : 0.0;
+ }
+ else if (direction == SYMBOL_V_LINE) {
+ return (delta.x < 0.5) ? 1.0 : 0.0;
+ }
+ else if (direction == SYMBOL_V_LINE + UP) {
+ return (coord.y <= 0.5 && delta.x < 0.5) ? 1.0 : 0.0;
+ }
+ else if (direction == SYMBOL_V_LINE + DOWN) {
+ return (coord.y >= 0.5 && delta.x < 0.5) ? 1.0 : 0.0;
+ }
+ return 1.0;
+ }
+
+ void main(void)
+ {
+ $clippingCall(vCameraPosition);
+
+ gl_FragColor = vColor;
+
+ float alpha = 1.0;
+ float symbol = floor(vSymbol);
+ if (1 == 1) { //symbol == SYMBOL_CIRCLE) {
+ alpha = alphaCircle(gl_PointCoord, vSize);
+ }
+ else if (symbol >= SYMBOL_H_LINE &&
+ symbol <= (SYMBOL_V_LINE + DOWN)) {
+ alpha = alphaLine(gl_PointCoord, vSize, symbol);
+ }
+ if (alpha == 0.0) {
+ discard;
+ }
+ gl_FragColor.a *= alpha;
+ }
+ """))
+
+ _ATTR_INFO = _POINTS_ATTR_INFO
+
+ def __init__(self, vertices, colors=(1., 1., 1., 1.), sizes=1.,
+ indices=None, symbols=0.,
+ minValue=None, maxValue=None):
+ super(ColorPoints, self).__init__('points', indices,
+ position=vertices,
+ color=colors,
+ size=sizes,
+ symbol=symbols)
+
+ def renderGL2(self, ctx):
+ fragment = self._shaders[1].substitute(
+ clippingDecl=ctx.clipper.fragDecl,
+ clippingCall=ctx.clipper.fragCall)
+ prog = ctx.glCtx.prog(self._shaders[0], fragment)
+ prog.use()
+
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
+ prog.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+
+ ctx.clipper.setupProgram(ctx, prog)
+
+ self._draw(prog)
+
+
+class GridPoints(Geometry):
+ # GLSL 1.30 !
+ """Data points on a regular grid with an associated value and size."""
+ _shaders = ("""
+ #version 130
+
+ in float value;
+ in float size;
+
+ uniform ivec3 gridDims;
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+ uniform vec2 valRange;
+
+ out vec4 vCameraPosition;
+ out float vNormValue;
+
+ //ivec3 coordsFromIndex(int index, ivec3 shape)
+ //{
+ /*Assumes that data is stored as z-major, then y, contiguous on x
+ */
+ // int yxPlaneSize = shape.y * shape.x; /* nb of elem in 2d yx plane */
+ // int z = index / yxPlaneSize;
+ // int yxIndex = index - z * yxPlaneSize; /* index in 2d yx plane */
+ // int y = yxIndex / shape.x;
+ // int x = yxIndex - y * shape.x;
+ // return ivec3(x, y, z);
+ // }
+
+ ivec3 coordsFromIndex(int index, ivec3 shape)
+ {
+ /*Assumes that data is stored as x-major, then y, contiguous on z
+ */
+ int yzPlaneSize = shape.y * shape.z; /* nb of elem in 2d yz plane */
+ int x = index / yzPlaneSize;
+ int yzIndex = index - x * yzPlaneSize; /* index in 2d yz plane */
+ int y = yzIndex / shape.z;
+ int z = yzIndex - y * shape.z;
+ return ivec3(x, y, z);
+ }
+
+ void main(void)
+ {
+ vNormValue = clamp((value - valRange.x) / (valRange.y - valRange.x),
+ 0.0, 1.0);
+
+ bool isValueInRange = value >= valRange.x && value <= valRange.y;
+ if (isValueInRange) {
+ /* Retrieve 3D position from gridIndex */
+ vec3 coords = vec3(coordsFromIndex(gl_VertexID, gridDims));
+ vec3 position = coords / max(vec3(gridDims) - 1.0, 1.0);
+ gl_Position = matrix * vec4(position, 1.0);
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ } else {
+ gl_Position = vec4(2.0, 0.0, 0.0, 1.0); /* Get clipped */
+ vCameraPosition = vec4(0.0, 0.0, 0.0, 0.0);
+ }
+
+ gl_PointSize = size;
+ }
+ """,
+ string.Template("""
+ #version 130
+
+ in vec4 vCameraPosition;
+ in float vNormValue;
+ out vec4 fragColor;
+
+ $clippingDecl
+
+ void main(void)
+ {
+ $clippingCall(vCameraPosition);
+
+ fragColor = vec4(0.5 * vNormValue + 0.5, 0.0, 0.0, 1.0);
+ }
+ """))
+
+ _ATTR_INFO = {
+ 'value': {'dims': (1, 2), 'lastDim': (1,)},
+ 'size': {'dims': (1, 2), 'lastDim': (1,)}
+ }
+
+ # TODO Add colormap, shape?
+ # TODO could also use a texture to store values
+
+ def __init__(self, values=0., shape=None, sizes=1., indices=None,
+ minValue=None, maxValue=None):
+ if isinstance(values, collections.Iterable):
+ values = numpy.array(values, copy=False)
+
+ # Test if gl_VertexID will overflow
+ assert values.size < numpy.iinfo(numpy.int32).max
+
+ self._shape = values.shape
+ values = values.ravel() # 1D to add as a 1D vertex attribute
+
+ else:
+ assert shape is not None
+ self._shape = tuple(shape)
+
+ assert len(self._shape) in (1, 2, 3)
+
+ super(GridPoints, self).__init__('points', indices,
+ value=values,
+ size=sizes)
+
+ data = self.getAttribute('value', copy=False)
+ self._minValue = data.min() if minValue is None else minValue
+ self._maxValue = data.max() if maxValue is None else maxValue
+
+ minValue = event.notifyProperty('_minValue')
+ maxValue = event.notifyProperty('_maxValue')
+
+ def _bounds(self, dataBounds=False):
+ # Get bounds from values shape
+ bounds = numpy.zeros((2, 3), dtype=numpy.float32)
+ bounds[1, :] = self._shape
+ bounds[1, :] -= 1
+ return bounds
+
+ def renderGL2(self, ctx):
+ fragment = self._shaders[1].substitute(
+ clippingDecl=ctx.clipper.fragDecl,
+ clippingCall=ctx.clipper.fragCall)
+ prog = ctx.glCtx.prog(self._shaders[0], fragment)
+ prog.use()
+
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
+ prog.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+
+ ctx.clipper.setupProgram(ctx, prog)
+
+ gl.glUniform3i(prog.uniforms['gridDims'],
+ self._shape[2] if len(self._shape) == 3 else 1,
+ self._shape[1] if len(self._shape) >= 2 else 1,
+ self._shape[0])
+
+ gl.glUniform2f(prog.uniforms['valRange'], self.minValue, self.maxValue)
+
+ self._draw(prog, nbVertices=reduce(lambda a, b: a * b, self._shape))
+
+
+# Spheres #####################################################################
+
+class Spheres(Geometry):
+ """A set of spheres.
+
+ Spheres are rendered as circles using points.
+ This brings some limitations:
+ - Do not support non-uniform scaling.
+ - Assume the projection keeps ratio.
+ - Do not render distorion by perspective projection.
+ - If the sphere center is clipped, the whole sphere is not displayed.
+ """
+ # TODO check those links
+ # Accounting for perspective projection
+ # http://iquilezles.org/www/articles/sphereproj/sphereproj.htm
+
+ # Michael Mara and Morgan McGuire.
+ # 2D Polyhedral Bounds of a Clipped, Perspective-Projected 3D Sphere
+ # Journal of Computer Graphics Techniques, Vol. 2, No. 2, 2013.
+ # http://jcgt.org/published/0002/02/05/paper.pdf
+ # https://research.nvidia.com/publication/2d-polyhedral-bounds-clipped-perspective-projected-3d-sphere
+
+ # TODO some issues with small scaling and regular grid or due to sampling
+
+ _shaders = ("""
+ #version 120
+
+ attribute vec3 position;
+ attribute vec4 color;
+ attribute float radius;
+
+ uniform mat4 transformMat;
+ uniform mat4 projMat;
+ uniform vec2 screenSize;
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec4 vColor;
+ varying float vViewDepth;
+ varying float vViewRadius;
+
+ void main(void)
+ {
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ gl_Position = projMat * vCameraPosition;
+
+ vPosition = gl_Position.xyz / gl_Position.w;
+
+ /* From object space radius to view space diameter.
+ * Do not support non-uniform scaling */
+ vec4 viewSizeVector = transformMat * vec4(2.0 * radius, 0.0, 0.0, 0.0);
+ float viewSize = length(viewSizeVector.xyz);
+
+ /* Convert to pixel size at the xy center of the view space */
+ vec4 projSize = projMat * vec4(0.5 * viewSize, 0.0,
+ vCameraPosition.z, vCameraPosition.w);
+ gl_PointSize = max(1.0, screenSize[0] * projSize.x / projSize.w);
+
+ vColor = color;
+ vViewRadius = 0.5 * viewSize;
+ vViewDepth = vCameraPosition.z;
+ }
+ """,
+ string.Template("""
+ # version 120
+
+ uniform mat4 projMat;
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec4 vColor;
+ varying float vViewDepth;
+ varying float vViewRadius;
+
+ $clippingDecl
+ $lightingFunction
+
+ void main(void)
+ {
+ $clippingCall(vCameraPosition);
+
+ /* Get normal from point coords */
+ vec3 normal;
+ normal.xy = 2.0 * gl_PointCoord - vec2(1.0);
+ normal.y *= -1.0; /*Invert y to match NDC orientation*/
+ float sqLength = dot(normal.xy, normal.xy);
+ if (sqLength > 1.0) { /* Length -> out of sphere */
+ discard;
+ }
+ normal.z = sqrt(1.0 - sqLength);
+
+ /*Lighting performed in NDC*/
+ /*TODO update this when lighting changed*/
+ //XXX vec3 position = vPosition + vViewRadius * normal;
+ gl_FragColor = $lightingCall(vColor, vPosition, normal);
+
+ /*Offset depth*/
+ float viewDepth = vViewDepth + vViewRadius * normal.z;
+ vec2 clipZW = viewDepth * projMat[2].zw + projMat[3].zw;
+ gl_FragDepth = 0.5 * (clipZW.x / clipZW.y) + 0.5;
+ }
+ """))
+
+ _ATTR_INFO = {
+ 'position': {'dims': (2, ), 'lastDim': (2, 3, 4)},
+ 'radius': {'dims': (1, 2), 'lastDim': (1, )},
+ 'color': {'dims': (1, 2), 'lastDim': (3, 4)},
+ }
+
+ def __init__(self, positions, radius=1., colors=(1., 1., 1., 1.)):
+ self.__bounds = None
+ super(Spheres, self).__init__('points', None,
+ position=positions,
+ radius=radius,
+ color=colors)
+
+ def renderGL2(self, ctx):
+ fragment = self._shaders[1].substitute(
+ clippingDecl=ctx.clipper.fragDecl,
+ clippingCall=ctx.clipper.fragCall,
+ lightingFunction=ctx.viewport.light.fragmentDef,
+ lightingCall=ctx.viewport.light.fragmentCall)
+ prog = ctx.glCtx.prog(self._shaders[0], fragment)
+ prog.use()
+
+ ctx.viewport.light.setupProgram(ctx, prog)
+
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ prog.setUniformMatrix('projMat', ctx.projection.matrix)
+ prog.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+
+ ctx.clipper.setupProgram(ctx, prog)
+
+ gl.glUniform2f(prog.uniforms['screenSize'], *ctx.viewport.size)
+
+ self._draw(prog)
+
+ def _bounds(self, dataBounds=False):
+ if self.__bounds is None:
+ self.__bounds = numpy.zeros((2, 3), dtype=numpy.float32)
+ # Support vertex with to 2 to 4 coordinates
+ positions = self._attributes['position']
+ radius = self._attributes['radius']
+ self.__bounds[0, :positions.shape[1]] = \
+ (positions - radius).min(axis=0)[:3]
+ self.__bounds[1, :positions.shape[1]] = \
+ (positions + radius).max(axis=0)[:3]
+ return self.__bounds.copy()
+
+
+# Meshes ######################################################################
+
+class Mesh3D(Geometry):
+ """A conventional 3D mesh"""
+
+ _shaders = ("""
+ attribute vec3 position;
+ attribute vec3 normal;
+ attribute vec4 color;
+
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+ //uniform mat3 matrixInvTranspose;
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec4 vColor;
+
+ void main(void)
+ {
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ //vNormal = matrixInvTranspose * normalize(normal);
+ vPosition = position;
+ vNormal = normal;
+ vColor = color;
+ gl_Position = matrix * vec4(position, 1.0);
+ }
+ """,
+ string.Template("""
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec4 vColor;
+
+ $clippingDecl
+ $lightingFunction
+
+ void main(void)
+ {
+ $clippingCall(vCameraPosition);
+
+ gl_FragColor = $lightingCall(vColor, vPosition, vNormal);
+ }
+ """))
+
+ def __init__(self,
+ positions,
+ colors,
+ normals=None,
+ mode='triangles',
+ indices=None):
+ assert mode in self._TRIANGLE_MODES
+ super(Mesh3D, self).__init__(mode, indices,
+ position=positions,
+ normal=normals,
+ color=colors)
+
+ self._culling = None
+
+ @property
+ def culling(self):
+ """Face culling (str)
+
+ One of 'back', 'front' or None.
+ """
+ return self._culling
+
+ @culling.setter
+ def culling(self, culling):
+ assert culling in ('back', 'front', None)
+ if culling != self._culling:
+ self._culling = culling
+ self.notify()
+
+ def renderGL2(self, ctx):
+ isnormals = 'normal' in self._attributes
+ if isnormals:
+ fragLightFunction = ctx.viewport.light.fragmentDef
+ else:
+ fragLightFunction = ctx.viewport.light.fragmentShaderFunctionNoop
+
+ fragment = self._shaders[1].substitute(
+ clippingDecl=ctx.clipper.fragDecl,
+ clippingCall=ctx.clipper.fragCall,
+ lightingFunction=fragLightFunction,
+ lightingCall=ctx.viewport.light.fragmentCall)
+ prog = ctx.glCtx.prog(self._shaders[0], fragment)
+ prog.use()
+
+ if isnormals:
+ ctx.viewport.light.setupProgram(ctx, prog)
+
+ if self.culling is not None:
+ cullFace = gl.GL_FRONT if self.culling == 'front' else gl.GL_BACK
+ gl.glCullFace(cullFace)
+ gl.glEnable(gl.GL_CULL_FACE)
+
+ prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
+ prog.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+
+ ctx.clipper.setupProgram(ctx, prog)
+
+ self._draw(prog)
+
+ if self.culling is not None:
+ gl.glDisable(gl.GL_CULL_FACE)
+
+
+# Group ######################################################################
+
+# TODO lighting, clipping as groups?
+# group composition?
+
+class GroupDepthOffset(core.Group):
+ """A group using 2-pass rendering and glDepthRange to avoid Z-fighting"""
+
+ def __init__(self, children=(), epsilon=None):
+ super(GroupDepthOffset, self).__init__(children)
+ self._epsilon = epsilon
+ self.isDepthRangeOn = True
+
+ def prepareGL2(self, ctx):
+ if self._epsilon is None:
+ depthbits = gl.glGetInteger(gl.GL_DEPTH_BITS)
+ self._epsilon = 1. / (1 << (depthbits - 1))
+
+ def renderGL2(self, ctx):
+ if self.isDepthRangeOn:
+ self._renderGL2WithDepthRange(ctx)
+ else:
+ super(GroupDepthOffset, self).renderGL2(ctx)
+
+ def _renderGL2WithDepthRange(self, ctx):
+ # gl.glDepthFunc(gl.GL_LESS)
+ with gl.enabled(gl.GL_CULL_FACE):
+ gl.glCullFace(gl.GL_BACK)
+ for child in self.children:
+ gl.glColorMask(
+ gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
+ gl.glDepthMask(gl.GL_TRUE)
+ gl.glDepthRange(self._epsilon, 1.)
+
+ child.render(ctx)
+
+ gl.glColorMask(
+ gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
+ gl.glDepthMask(gl.GL_FALSE)
+ gl.glDepthRange(0., 1. - self._epsilon)
+
+ child.render(ctx)
+
+ gl.glCullFace(gl.GL_FRONT)
+ for child in reversed(self.children):
+ gl.glColorMask(
+ gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
+ gl.glDepthMask(gl.GL_TRUE)
+ gl.glDepthRange(self._epsilon, 1.)
+
+ child.render(ctx)
+
+ gl.glColorMask(
+ gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
+ gl.glDepthMask(gl.GL_FALSE)
+ gl.glDepthRange(0., 1. - self._epsilon)
+
+ child.render(ctx)
+
+ gl.glDepthMask(gl.GL_TRUE)
+ gl.glDepthRange(0., 1.)
+ # gl.glDepthFunc(gl.GL_LEQUAL)
+ # TODO use epsilon for all rendering?
+ # TODO issue with picking in depth buffer!
+
+
+class GroupBBox(core.PrivateGroup):
+ """A group displaying a bounding box around the children."""
+
+ def __init__(self, children=(), color=(1., 1., 1., 1.)):
+ super(GroupBBox, self).__init__()
+ self._group = core.Group(children)
+
+ self._boxTransforms = transform.TransformList(
+ (transform.Translate(), transform.Scale()))
+
+ self._boxWithAxes = BoxWithAxes(color)
+ self._boxWithAxes.smooth = False
+ self._boxWithAxes.transforms = self._boxTransforms
+
+ self._children = [self._boxWithAxes, self._group]
+
+ def _updateBoxAndAxes(self):
+ """Update bbox and axes position and size according to children."""
+ bounds = self._group.bounds(dataBounds=True)
+ if bounds is not None:
+ origin = bounds[0]
+ scale = [(d if d != 0. else 1.) for d in bounds[1] - bounds[0]]
+ else:
+ origin, scale = (0., 0., 0.), (1., 1., 1.)
+
+ self._boxTransforms[0].translation = origin
+ self._boxTransforms[1].scale = scale
+
+ def _bounds(self, dataBounds=False):
+ self._updateBoxAndAxes()
+ return super(GroupBBox, self)._bounds(dataBounds)
+
+ def prepareGL2(self, ctx):
+ self._updateBoxAndAxes()
+ super(GroupBBox, self).prepareGL2(ctx)
+
+ # Give access to _group children
+
+ @property
+ def children(self):
+ return self._group.children
+
+ @children.setter
+ def children(self, iterable):
+ self._group.children = iterable
+
+ # Give access to box color
+
+ @property
+ def color(self):
+ """The RGBA color to use for the box: 4 float in [0, 1]"""
+ return self._boxWithAxes.color
+
+ @color.setter
+ def color(self, color):
+ self._boxWithAxes.color = color
+
+
+# Clipping Plane ##############################################################
+
+class ClipPlane(PlaneInGroup):
+ """A clipping plane attached to a box"""
+
+ def renderGL2(self, ctx):
+ super(ClipPlane, self).renderGL2(ctx)
+
+ if self.visible:
+ # Set-up clipping plane for following brothers
+
+ # No need of perspective divide, no projection
+ point = ctx.objectToCamera.transformPoint(self.plane.point,
+ perspectiveDivide=False)
+ normal = ctx.objectToCamera.transformNormal(self.plane.normal)
+ ctx.setClipPlane(point, normal)
+
+ def postRender(self, ctx):
+ if self.visible:
+ # Disable clip planes
+ ctx.setClipPlane()
diff --git a/silx/gui/plot3d/scene/setup.py b/silx/gui/plot3d/scene/setup.py
new file mode 100644
index 0000000..ff4c0a6
--- /dev/null
+++ b/silx/gui/plot3d/scene/setup.py
@@ -0,0 +1,41 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('scene', parent_package, top_path)
+ config.add_subpackage('test')
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/silx/gui/plot3d/scene/test/__init__.py b/silx/gui/plot3d/scene/test/__init__.py
new file mode 100644
index 0000000..fc4621e
--- /dev/null
+++ b/silx/gui/plot3d/scene/test/__init__.py
@@ -0,0 +1,43 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import unittest
+
+from .test_transform import suite as test_transform_suite
+from .test_utils import suite as test_utils_suite
+
+
+def suite():
+ testsuite = unittest.TestSuite()
+ testsuite.addTest(test_transform_suite())
+ testsuite.addTest(test_utils_suite())
+ return testsuite
diff --git a/silx/gui/plot3d/scene/test/test_transform.py b/silx/gui/plot3d/scene/test/test_transform.py
new file mode 100644
index 0000000..9ea0af1
--- /dev/null
+++ b/silx/gui/plot3d/scene/test/test_transform.py
@@ -0,0 +1,91 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/01/2017"
+
+
+import numpy
+import unittest
+
+from silx.gui.plot3d.scene import transform
+
+
+class TestTransformList(unittest.TestCase):
+
+ def assertSameArrays(self, a, b):
+ return self.assertTrue(numpy.allclose(a, b, atol=1e-06))
+
+ def testTransformList(self):
+ """Minimalistic test of TransformList"""
+ transforms = transform.TransformList()
+ refmatrix = numpy.identity(4, dtype=numpy.float32)
+ self.assertSameArrays(refmatrix, transforms.matrix)
+
+ # Append translate
+ transforms.append(transform.Translate(1., 1., 1.))
+ refmatrix = numpy.array(((1., 0., 0., 1.),
+ (0., 1., 0., 1.),
+ (0., 0., 1., 1.),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+ self.assertSameArrays(refmatrix, transforms.matrix)
+
+ # Extend scale
+ transforms.extend([transform.Scale(0.1, 2., 1.)])
+ refmatrix = numpy.dot(refmatrix,
+ numpy.array(((0.1, 0., 0., 0.),
+ (0., 2., 0., 0.),
+ (0., 0., 1., 0.),
+ (0., 0., 0., 1.)),
+ dtype=numpy.float32))
+ self.assertSameArrays(refmatrix, transforms.matrix)
+
+ # Insert rotate
+ transforms.insert(0, transform.Rotate(360.))
+ self.assertSameArrays(refmatrix, transforms.matrix)
+
+ # Update translate and check for listener called
+ self._callCount = 0
+
+ def listener(source):
+ self._callCount += 1
+ transforms.addListener(listener)
+
+ transforms[1].tx += 1
+ self.assertEqual(self._callCount, 1)
+
+
+def suite():
+ testsuite = unittest.TestSuite()
+ testsuite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestTransformList))
+ return testsuite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot3d/scene/test/test_utils.py b/silx/gui/plot3d/scene/test/test_utils.py
new file mode 100644
index 0000000..65c2407
--- /dev/null
+++ b/silx/gui/plot3d/scene/test/test_utils.py
@@ -0,0 +1,275 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import unittest
+from silx.test.utils import ParametricTestCase
+
+import numpy
+
+from silx.gui.plot3d.scene import utils
+
+
+# angleBetweenVectors #########################################################
+
+class TestAngleBetweenVectors(ParametricTestCase):
+
+ TESTS = { # name: (refvector, vectors, norm, refangles)
+ 'single vector':
+ ((1., 0., 0.), (1., 0., 0.), (0., 0., 1.), 0.),
+ 'single vector, no norm':
+ ((1., 0., 0.), (1., 0., 0.), None, 0.),
+
+ 'with orthogonal norm':
+ ((1., 0., 0.),
+ ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
+ (0., 0., 1.),
+ (0., 90., 180., 270.)),
+
+ 'with coplanar norm': # = similar to no norm
+ ((1., 0., 0.),
+ ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
+ (1., 0., 0.),
+ (0., 90., 180., 90.)),
+
+ 'without norm':
+ ((1., 0., 0.),
+ ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
+ None,
+ (0., 90., 180., 90.)),
+
+ 'not unit vectors':
+ ((2., 2., 0.), ((1., 1., 0.), (1., -1., 0.)), None, (0., 90.)),
+ }
+
+ def testAngleBetweenVectorsFunction(self):
+ for name, params in self.TESTS.items():
+ refvector, vectors, norm, refangles = params
+ with self.subTest(name):
+ refangles = numpy.radians(refangles)
+
+ refvector = numpy.array(refvector)
+ vectors = numpy.array(vectors)
+ if norm is not None:
+ norm = numpy.array(norm)
+
+ testangles = utils.angleBetweenVectors(
+ refvector, vectors, norm)
+
+ self.assertTrue(
+ numpy.allclose(testangles, refangles, atol=1e-5))
+
+
+# Plane #######################################################################
+
+class AssertNotificationContext(object):
+ """Context that checks if an event.Notifier is sending events."""
+
+ def __init__(self, notifier, count=1):
+ """Initializer.
+
+ :param event.Notifier notifier: The notifier to test.
+ :param int count: The expected number of calls.
+ """
+ self._notifier = notifier
+ self._callCount = None
+ self._count = count
+
+ def __enter__(self):
+ self._callCount = 0
+ self._notifier.addListener(self._callback)
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ # Do not return True so exceptions are propagated
+ self._notifier.removeListener(self._callback)
+ assert self._callCount == self._count
+ self._callCount = None
+
+ def _callback(self, *args, **kwargs):
+ self._callCount += 1
+
+
+class TestPlaneParameters(ParametricTestCase):
+ """Test Plane.parameters read/write and notifications."""
+
+ PARAMETERS = {
+ 'unit normal': (1., 0., 0., 1.),
+ 'not unit normal': (1., 1., 0., 1.),
+ 'd = 0': (1., 0., 0., 0.)
+ }
+
+ def testParameters(self):
+ """Check parameters read/write and notification."""
+ plane = utils.Plane()
+
+ for name, parameters in self.PARAMETERS.items():
+ with self.subTest(name, parameters=parameters):
+ with AssertNotificationContext(plane):
+ plane.parameters = parameters
+
+ # Plane parameters are converted to have a unit normal
+ normparams = parameters / numpy.linalg.norm(parameters[:3])
+ self.assertTrue(numpy.allclose(plane.parameters, normparams))
+
+ ZEROS_PARAMETERS = (
+ (0., 0., 0., 0.),
+ (0., 0., 0., 1.)
+ )
+
+ ZEROS = 0., 0., 0., 0.
+
+ def testParametersNoPlane(self):
+ """Test Plane.parameters with ||normal|| == 0 ."""
+ plane = utils.Plane()
+ plane.parameters = self.ZEROS
+
+ for parameters in self.ZEROS_PARAMETERS:
+ with self.subTest(parameters=parameters):
+ with AssertNotificationContext(plane, count=0):
+ plane.parameters = parameters
+ self.assertTrue(
+ numpy.allclose(plane.parameters, self.ZEROS, 0., 0.))
+
+
+# unindexArrays ###############################################################
+
+class TestUnindexArrays(ParametricTestCase):
+ """Test unindexArrays function."""
+
+ def testBasicModes(self):
+ """Test for modes: points, lines and triangles"""
+ indices = numpy.array((1, 2, 0))
+ arrays = (numpy.array((0., 1., 2.)),
+ numpy.array(((0, 0), (1, 1), (2, 2))))
+ refresults = (numpy.array((1., 2., 0.)),
+ numpy.array(((1, 1), (2, 2), (0, 0))))
+
+ for mode in ('points', 'lines', 'triangles'):
+ with self.subTest(mode=mode):
+ testresults = utils.unindexArrays(mode, indices, *arrays)
+ for ref, test in zip(refresults, testresults):
+ self.assertTrue(numpy.equal(ref, test).all())
+
+ def testPackedLines(self):
+ """Test for modes: line_strip, loop"""
+ indices = numpy.array((1, 2, 0))
+ arrays = (numpy.array((0., 1., 2.)),
+ numpy.array(((0, 0), (1, 1), (2, 2))))
+ results = {
+ 'line_strip': (
+ numpy.array((1., 2., 2., 0.)),
+ numpy.array(((1, 1), (2, 2), (2, 2), (0, 0)))),
+ 'loop': (
+ numpy.array((1., 2., 2., 0., 0., 1.)),
+ numpy.array(((1, 1), (2, 2), (2, 2), (0, 0), (0, 0), (1, 1)))),
+ }
+
+ for mode, refresults in results.items():
+ with self.subTest(mode=mode):
+ testresults = utils.unindexArrays(mode, indices, *arrays)
+ for ref, test in zip(refresults, testresults):
+ self.assertTrue(numpy.equal(ref, test).all())
+
+ def testPackedTriangles(self):
+ """Test for modes: triangle_strip, fan"""
+ indices = numpy.array((1, 2, 0, 3))
+ arrays = (numpy.array((0., 1., 2., 3.)),
+ numpy.array(((0, 0), (1, 1), (2, 2), (3, 3))))
+ results = {
+ 'triangle_strip': (
+ numpy.array((1., 2., 0., 2., 0., 3.)),
+ numpy.array(((1, 1), (2, 2), (0, 0), (2, 2), (0, 0), (3, 3)))),
+ 'fan': (
+ numpy.array((1., 2., 0., 1., 0., 3.)),
+ numpy.array(((1, 1), (2, 2), (0, 0), (1, 1), (0, 0), (3, 3)))),
+ }
+
+ for mode, refresults in results.items():
+ with self.subTest(mode=mode):
+ testresults = utils.unindexArrays(mode, indices, *arrays)
+ for ref, test in zip(refresults, testresults):
+ self.assertTrue(numpy.equal(ref, test).all())
+
+ def testBadIndices(self):
+ """Test with negative indices and indices higher than array length"""
+ arrays = numpy.array((0, 1)), numpy.array((0, 1, 2))
+
+ # negative indices
+ with self.assertRaises(AssertionError):
+ utils.unindexArrays('points', (-1, 0), *arrays)
+
+ # Too high indices
+ with self.assertRaises(AssertionError):
+ utils.unindexArrays('points', (0, 10), *arrays)
+
+
+# triangleNormals #############################################################
+
+class TestTriangleNormals(ParametricTestCase):
+ """Test triangleNormals function."""
+
+ def test(self):
+ """Test for modes: points, lines and triangles"""
+ positions = numpy.array(
+ ((0., 0., 0.), (1., 0., 0.), (0., 1., 0.), # normal = Z
+ (1., 1., 1.), (1., 2., 3.), (4., 5., 6.), # Random triangle
+ # Degenerated triangles:
+ (0., 0., 0.), (1., 0., 0.), (2., 0., 0.), # Colinear points
+ (1., 1., 1.), (1., 1., 1.), (1., 1., 1.), # All same point
+ ),
+ dtype='float32')
+
+ normals = numpy.array(
+ ((0., 0., 1.),
+ (-0.40824829, 0.81649658, -0.40824829),
+ (0., 0., 0.),
+ (0., 0., 0.)),
+ dtype='float32')
+
+ testnormals = utils.trianglesNormal(positions)
+ self.assertTrue(numpy.allclose(testnormals, normals))
+
+
+# suite #######################################################################
+
+def suite():
+ testsuite = unittest.TestSuite()
+ for test in (TestAngleBetweenVectors,
+ TestPlaneParameters,
+ TestUnindexArrays,
+ TestTriangleNormals):
+ testsuite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(test))
+ return testsuite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot3d/scene/text.py b/silx/gui/plot3d/scene/text.py
new file mode 100644
index 0000000..903fc21
--- /dev/null
+++ b/silx/gui/plot3d/scene/text.py
@@ -0,0 +1,534 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Primitive displaying a text field in the scene."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/10/2016"
+
+
+import logging
+import numpy
+
+from silx.gui.plot.Colors import rgba
+
+from ... import _glutils
+from ..._glutils import gl
+
+from ..._glutils import font as _font
+from ...plot._utils import ticklayout
+
+from . import event, primitives, core, transform
+
+
+_logger = logging.getLogger(__name__)
+
+
+class Font(event.Notifier):
+ """Description of a font.
+
+ :param str name: Family of the font
+ :param int size: Size of the font in points
+ :param int weight: Font weight
+ :param bool italic: True for italic font, False (default) otherwise
+ """
+
+ def __init__(self, name=None, size=-1, weight=-1, italic=False):
+ self._name = name if name is not None else _font.getDefaultFontFamily()
+ self._size = size
+ self._weight = weight
+ self._italic = italic
+ super(Font, self).__init__()
+
+ name = event.notifyProperty(
+ '_name',
+ doc="""Name of the font (str)""",
+ converter=str)
+
+ size = event.notifyProperty(
+ '_size',
+ doc="""Font size in points (int)""",
+ converter=int)
+
+ weight = event.notifyProperty(
+ '_weight',
+ doc="""Font size in points (int)""",
+ converter=int)
+
+ italic = event.notifyProperty(
+ '_italic',
+ doc="""True for italic (bool)""",
+ converter=bool)
+
+
+class Text2D(primitives.Geometry):
+ """Text field as a 2D texture displayed with bill-boarding
+
+ :param str text: Text to display
+ :param Font font: The font to use
+ """
+
+ # Text anchor values
+ CENTER = 'center'
+
+ LEFT = 'left'
+ RIGHT = 'right'
+
+ TOP = 'top'
+ BASELINE = 'baseline'
+ BOTTOM = 'bottom'
+
+ _ALIGN = LEFT, CENTER, RIGHT
+ _VALIGN = TOP, BASELINE, CENTER, BOTTOM
+
+ _rasterTextCache = {}
+ """Internal cache storing already rasterized text"""
+ # TODO limit cache size and discard least recent used
+
+ def __init__(self, text='', font=None):
+ self._dirtyTexture = True
+ self._dirtyAlign = True
+ self._baselineOffset = 0
+ self._text = text
+ self._font = font if font is not None else Font()
+ self._foreground = 1., 1., 1., 1.
+ self._background = 0., 0., 0., 0.
+ self._overlay = False
+ self._align = 'left'
+ self._valign = 'baseline'
+ self._devicePixelRatio = 1.0 # Store it to check for changes
+
+ self._texture = None
+ self._textureDirty = True
+
+ super(Text2D, self).__init__(
+ 'triangle_strip',
+ copy=False,
+ # Keep an array for position as it is bound to attr 0 and MUST
+ # be active and an array at least on Mac OS X
+ position=numpy.zeros((4, 3), dtype=numpy.float32),
+ vertexID=numpy.arange(4., dtype=numpy.float32).reshape(4, 1),
+ offsetInViewportCoords=(0., 0.))
+
+ @property
+ def text(self):
+ """Text displayed by this primitive (str)"""
+ return self._text
+
+ @text.setter
+ def text(self, text):
+ text = str(text)
+ if text != self._text:
+ self._dirtyTexture = True
+ self._text = text
+ self.notify()
+
+ @property
+ def font(self):
+ """Font to use to raster text (Font)"""
+ return self._font
+
+ @font.setter
+ def font(self, font):
+ self._font.removeListener(self._fontChanged)
+ self._font = font
+ self._font.addListener(self._fontChanged)
+ self._fontChanged(self) # Which calls notify and primitive as dirty
+
+ def _fontChanged(self, source):
+ """Listen for font change"""
+ self._dirtyTexture = True
+ self.notify()
+
+ foreground = event.notifyProperty(
+ '_foreground', doc="""RGBA color of the text: 4 float in [0, 1]""",
+ converter=rgba)
+
+ background = event.notifyProperty(
+ '_background',
+ doc="RGBA background color of the text field: 4 float in [0, 1]",
+ converter=rgba)
+
+ overlay = event.notifyProperty(
+ '_overlay',
+ doc="True to always display text on top of the scene (default: False)",
+ converter=bool)
+
+ def _setAlign(self, align):
+ assert align in self._ALIGN
+ self._align = align
+ self._dirtyAlign = True
+ self.notify()
+
+ align = property(
+ lambda self: self._align,
+ _setAlign,
+ doc="""Horizontal anchor position of the text field (str).
+
+ Either 'left' (default), 'center' or 'right'.""")
+
+ def _setVAlign(self, valign):
+ assert valign in self._VALIGN
+ self._valign = valign
+ self._dirtyAlign = True
+ self.notify()
+
+ valign = property(
+ lambda self: self._valign,
+ _setVAlign,
+ doc="""Vertical anchor position of the text field (str).
+
+ Either 'top', 'baseline' (default), 'center' or 'bottom'""")
+
+ def _raster(self, devicePixelRatio):
+ """Raster current primitive to a bitmap
+
+ :param float devicePixelRatio:
+ The ratio between device and device-independent pixels
+ :return: Corresponding image in grayscale and baseline offset from top
+ :rtype: (HxW numpy.ndarray of uint8, int)
+ """
+ params = (self.text,
+ self.font.name,
+ self.font.size,
+ self.font.weight,
+ self.font.italic,
+ devicePixelRatio)
+
+ if params not in self._rasterTextCache: # Add to cache
+ self._rasterTextCache[params] = _font.rasterText(*params)
+
+ array, offset = self._rasterTextCache[params]
+ return array.copy(), offset
+
+ def _bounds(self, dataBounds=False):
+ return None
+
+ def prepareGL2(self, context):
+ # Check if devicePixelRatio has changed since last rendering
+ devicePixelRatio = context.glCtx.devicePixelRatio
+ if self._devicePixelRatio != devicePixelRatio:
+ self._devicePixelRatio = devicePixelRatio
+ self._dirtyTexture = True
+
+ if self._dirtyTexture:
+ self._dirtyTexture = False
+
+ if self._texture is not None:
+ self._texture.discard()
+ self._texture = None
+ self._baselineOffset = 0
+
+ if self.text:
+ image, self._baselineOffset = self._raster(
+ self._devicePixelRatio)
+ self._texture = _glutils.Texture(
+ gl.GL_R8, image, gl.GL_RED,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=gl.GL_CLAMP_TO_EDGE)
+ self._dirtyAlign = True # To force update of offset
+
+ if self._dirtyAlign:
+ self._dirtyAlign = False
+
+ if self._texture is not None:
+ height, width = self._texture.shape
+
+ if self._align == 'left':
+ ox = 0.
+ elif self._align == 'center':
+ ox = - width // 2
+ elif self._align == 'right':
+ ox = - width
+ else:
+ _logger.error("Unsupported align: %s", self._align)
+ ox = 0.
+
+ if self._valign == 'top':
+ oy = 0.
+ elif self._valign == 'baseline':
+ oy = self._baselineOffset
+ elif self._valign == 'center':
+ oy = height // 2
+ elif self._valign == 'bottom':
+ oy = height
+ else:
+ _logger.error("Unsupported valign: %s", self._valign)
+ oy = 0.
+
+ offsets = (ox, oy) + numpy.array(
+ ((0., 0.), (width, 0.), (0., -height), (width, -height)),
+ dtype=numpy.float32)
+ self.setAttribute('offsetInViewportCoords', offsets)
+
+ super(Text2D, self).prepareGL2(context)
+
+ def renderGL2(self, context):
+ if not self.text:
+ return # Nothing to render
+
+ program = context.glCtx.prog(*self._shaders)
+ program.use()
+
+ program.setUniformMatrix('matrix', context.objectToNDC.matrix)
+ gl.glUniform2f(
+ program.uniforms['viewportSize'], *context.viewport.size)
+ gl.glUniform4f(program.uniforms['foreground'], *self.foreground)
+ gl.glUniform4f(program.uniforms['background'], *self.background)
+ gl.glUniform1i(program.uniforms['texture'], self._texture.texUnit)
+ gl.glUniform1i(program.uniforms['isOverlay'],
+ 1 if self._overlay else 0)
+
+ self._texture.bind()
+
+ if not self._overlay or not gl.glGetBoolean(gl.GL_DEPTH_TEST):
+ self._draw(program)
+ else: # overlay and depth test currently enabled
+ gl.glDisable(gl.GL_DEPTH_TEST)
+ self._draw(program)
+ gl.glEnable(gl.GL_DEPTH_TEST)
+
+ # TODO texture atlas + viewportSize as attribute to chain text rendering
+
+ _shaders = (
+ """
+ attribute vec3 position;
+ attribute vec2 offsetInViewportCoords; /* Offset in pixels (y upward) */
+ attribute float vertexID; /* Index of rectangle corner */
+
+ uniform mat4 matrix;
+ uniform vec2 viewportSize; /* Width, height of the viewport */
+ uniform int isOverlay;
+
+ varying vec2 texCoords;
+
+ void main(void)
+ {
+ vec4 clipPos = matrix * vec4(position, 1.0);
+ vec4 ndcPos = clipPos / clipPos.w; /* Perspective divide */
+
+ /* Align ndcPos with pixels in viewport-like coords (origin useless) */
+ vec2 viewportPos = floor((ndcPos.xy + vec2(1.0, 1.0)) * 0.5 * viewportSize);
+
+ /* Apply offset in viewport coords */
+ viewportPos += offsetInViewportCoords;
+
+ /* Convert back to NDC */
+ vec2 pointPos = 2.0 * viewportPos / viewportSize - vec2(1.0, 1.0);
+ float z = (isOverlay != 0) ? -1.0 : ndcPos.z;
+ gl_Position = vec4(pointPos, z, 1.0);
+
+ /* Index : texCoords:
+ * 0: (0., 0.)
+ * 1: (1., 0.)
+ * 2: (0., 1.)
+ * 3: (1., 1.)
+ */
+ texCoords = vec2(vertexID == 0.0 || vertexID == 2.0 ? 0.0 : 1.0,
+ vertexID < 1.5 ? 0.0 : 1.0);
+ }
+ """, # noqa
+
+ """
+ varying vec2 texCoords;
+
+ uniform vec4 foreground;
+ uniform vec4 background;
+ uniform sampler2D texture;
+
+ void main(void)
+ {
+ float value = texture2D(texture, texCoords).r;
+
+ if (background.a != 0.0) {
+ gl_FragColor = mix(background, foreground, value);
+ } else {
+ gl_FragColor = foreground;
+ gl_FragColor.a *= value;
+ if (gl_FragColor.a <= 0.01) {
+ discard;
+ }
+ }
+ }
+ """)
+
+
+class LabelledAxes(primitives.GroupBBox):
+ """A group displaying a bounding box with axes labels around its children.
+ """
+
+ def __init__(self):
+ super(LabelledAxes, self).__init__()
+ self._ticksForBounds = None
+
+ self._font = Font()
+
+ # TODO offset labels from anchor in pixels
+
+ self._xlabel = Text2D(font=self._font)
+ self._xlabel.align = 'center'
+ self._xlabel.transforms = [self._boxTransforms,
+ transform.Translate(tx=0.5)]
+ self._children.append(self._xlabel)
+
+ self._ylabel = Text2D(font=self._font)
+ self._ylabel.align = 'center'
+ self._ylabel.transforms = [self._boxTransforms,
+ transform.Translate(ty=0.5)]
+ self._children.append(self._ylabel)
+
+ self._zlabel = Text2D(font=self._font)
+ self._zlabel.align = 'center'
+ self._zlabel.transforms = [self._boxTransforms,
+ transform.Translate(tz=0.5)]
+ self._children.append(self._zlabel)
+
+ self._tickLines = primitives.Lines( # Init tick lines with dummy pos
+ positions=((0., 0., 0.), (0., 0., 0.)),
+ mode='lines')
+ self._tickLines.visible = False
+ self._children.append(self._tickLines)
+
+ self._tickLabels = core.Group()
+ self._children.append(self._tickLabels)
+
+ @property
+ def font(self):
+ """Font of axes text labels (Font)"""
+ return self._font
+
+ @font.setter
+ def font(self, font):
+ self._font = font
+ self._xlabel.font = font
+ self._ylabel.font = font
+ self._zlabel.font = font
+ for label in self._tickLabels.children:
+ label.font = font
+
+ @property
+ def xlabel(self):
+ """Text label of the X axis (str)"""
+ return self._xlabel.text
+
+ @xlabel.setter
+ def xlabel(self, text):
+ self._xlabel.text = text
+
+ @property
+ def ylabel(self):
+ """Text label of the Y axis (str)"""
+ return self._ylabel.text
+
+ @ylabel.setter
+ def ylabel(self, text):
+ self._ylabel.text = text
+
+ @property
+ def zlabel(self):
+ """Text label of the Z axis (str)"""
+ return self._zlabel.text
+
+ @zlabel.setter
+ def zlabel(self, text):
+ self._zlabel.text = text
+
+ def _updateTicks(self):
+ """Check if ticks need update and update them if needed."""
+ bounds = self._group.bounds(transformed=False, dataBounds=True)
+ if bounds is None: # No content
+ if self._ticksForBounds is not None:
+ self._ticksForBounds = None
+ self._tickLines.visible = False
+ self._tickLabels.children = [] # Reset previous labels
+
+ elif (self._ticksForBounds is None or
+ not numpy.all(numpy.equal(bounds, self._ticksForBounds))):
+ self._ticksForBounds = bounds
+
+ # Update ticks
+ # TODO make ticks having a constant length on the screen
+ ticklength = numpy.abs(bounds[1] - bounds[0]) / 20.
+
+ xticks, xlabels = ticklayout.ticks(*bounds[:, 0])
+ yticks, ylabels = ticklayout.ticks(*bounds[:, 1])
+ zticks, zlabels = ticklayout.ticks(*bounds[:, 2])
+
+ # Update tick lines
+ coords = numpy.empty(
+ ((len(xticks) + len(yticks) + len(zticks)), 4, 3),
+ dtype=numpy.float32)
+ coords[:, :, :] = bounds[0, :] # account for offset from origin
+
+ xcoords = coords[:len(xticks)]
+ xcoords[:, :, 0] = numpy.asarray(xticks)[:, numpy.newaxis]
+ xcoords[:, 1, 1] += ticklength[1] # X ticks on XY plane
+ xcoords[:, 3, 2] += ticklength[2] # X ticks on XZ plane
+
+ ycoords = coords[len(xticks):len(xticks) + len(yticks)]
+ ycoords[:, :, 1] = numpy.asarray(yticks)[:, numpy.newaxis]
+ ycoords[:, 1, 0] += ticklength[0] # Y ticks on XY plane
+ ycoords[:, 3, 2] += ticklength[2] # Y ticks on YZ plane
+
+ zcoords = coords[len(xticks) + len(yticks):]
+ zcoords[:, :, 2] = numpy.asarray(zticks)[:, numpy.newaxis]
+ zcoords[:, 1, 0] += ticklength[0] # Z ticks on XZ plane
+ zcoords[:, 3, 1] += ticklength[1] # Z ticks on YZ plane
+
+ self._tickLines.setAttribute('position', coords.reshape(-1, 3))
+ self._tickLines.visible = True
+
+ # Update labels
+ offsets = bounds[0] - ticklength
+ labels = []
+ for tick, label in zip(xticks, xlabels):
+ text = Text2D(text=label, font=self.font)
+ text.align = 'center'
+ text.transforms = [transform.Translate(
+ tx=tick, ty=offsets[1], tz=offsets[2])]
+ labels.append(text)
+
+ for tick, label in zip(yticks, ylabels):
+ text = Text2D(text=label, font=self.font)
+ text.align = 'center'
+ text.transforms = [transform.Translate(
+ tx=offsets[0], ty=tick, tz=offsets[2])]
+ labels.append(text)
+
+ for tick, label in zip(zticks, zlabels):
+ text = Text2D(text=label, font=self.font)
+ text.align = 'center'
+ text.transforms = [transform.Translate(
+ tx=offsets[0], ty=offsets[1], tz=tick)]
+ labels.append(text)
+
+ self._tickLabels.children = labels # Reset previous labels
+
+ def prepareGL2(self, context):
+ self._updateTicks()
+ super(LabelledAxes, self).prepareGL2(context)
diff --git a/silx/gui/plot3d/scene/transform.py b/silx/gui/plot3d/scene/transform.py
new file mode 100644
index 0000000..71a6b74
--- /dev/null
+++ b/silx/gui/plot3d/scene/transform.py
@@ -0,0 +1,968 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides 4x4 matrix operation and classes to handle them."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import itertools
+import numpy
+
+from . import event
+
+
+# Functions ###################################################################
+
+# Projections
+
+def mat4LookAtDir(position, direction, up):
+ """Creates matrix to look in direction from position.
+
+ :param position: Array-like 3 coordinates of the point of view position.
+ :param direction: Array-like 3 coordinates of the sight direction vector.
+ :param up: Array-like 3 coordinates of the upward direction
+ in the image plane.
+ :returns: Corresponding matrix.
+ :rtype: numpy.ndarray of shape (4, 4)
+ """
+ assert len(position) == 3
+ assert len(direction) == 3
+ assert len(up) == 3
+
+ direction = numpy.array(direction, copy=True, dtype=numpy.float32)
+ dirnorm = numpy.linalg.norm(direction)
+ assert dirnorm != 0.
+ direction /= dirnorm
+
+ side = numpy.cross(direction,
+ numpy.array(up, copy=False, dtype=numpy.float32))
+ sidenorm = numpy.linalg.norm(side)
+ assert sidenorm != 0.
+ up = numpy.cross(side / sidenorm, direction)
+ upnorm = numpy.linalg.norm(up)
+ assert upnorm != 0.
+ up /= upnorm
+
+ matrix = numpy.identity(4, dtype=numpy.float32)
+ matrix[0, :3] = side
+ matrix[1, :3] = up
+ matrix[2, :3] = -direction
+ return numpy.dot(matrix,
+ mat4Translate(-position[0], -position[1], -position[2]))
+
+
+def mat4LookAt(position, center, up):
+ """Creates matrix to look at center from position.
+
+ See gluLookAt.
+
+ :param position: Array-like 3 coordinates of the point of view position.
+ :param center: Array-like 3 coordinates of the center of the scene.
+ :param up: Array-like 3 coordinates of the upward direction
+ in the image plane.
+ :returns: Corresponding matrix.
+ :rtype: numpy.ndarray of shape (4, 4)
+ """
+ position = numpy.array(position, copy=False, dtype=numpy.float32)
+ center = numpy.array(center, copy=False, dtype=numpy.float32)
+ direction = center - position
+ return mat4LookAtDir(position, direction, up)
+
+
+def mat4Frustum(left, right, bottom, top, near, far):
+ """Creates a frustum projection matrix.
+
+ See glFrustum.
+ """
+ return numpy.array((
+ (2.*near / (right-left), 0., (right+left) / (right-left), 0.),
+ (0., 2.*near / (top-bottom), (top+bottom) / (top-bottom), 0.),
+ (0., 0., -(far+near) / (far-near), -2.*far*near / (far-near)),
+ (0., 0., -1., 0.)), dtype=numpy.float32)
+
+
+def mat4Perspective(fovy, width, height, near, far):
+ """Creates a perspective projection matrix.
+
+ Similar to gluPerspective.
+
+ :param float fovy: Field of view angle in degrees in the y direction.
+ :param float width: Width of the viewport.
+ :param float height: Height of the viewport.
+ :param float near: Distance to the near plane (strictly positive).
+ :param float far: Distance to the far plane (strictly positive).
+ :return: Corresponding matrix.
+ :rtype: numpy.ndarray of shape (4, 4)
+ """
+ assert fovy != 0
+ assert height != 0
+ assert width != 0
+ assert near > 0.
+ assert far > near
+ aspectratio = width / height
+ f = 1. / numpy.tan(numpy.radians(fovy) / 2.)
+ return numpy.array((
+ (f / aspectratio, 0., 0., 0.),
+ (0., f, 0., 0.),
+ (0., 0., (far + near) / (near - far), 2. * far * near / (near - far)),
+ (0., 0., -1., 0.)), dtype=numpy.float32)
+
+
+def mat4Orthographic(left, right, bottom, top, near, far):
+ """Creates an orthographic (i.e., parallel) projection matrix.
+
+ See glOrtho.
+ """
+ return numpy.array((
+ (2. / (right - left), 0., 0., - (right + left) / (right - left)),
+ (0., 2. / (top - bottom), 0., - (top + bottom) / (top - bottom)),
+ (0., 0., -2. / (far - near), - (far + near) / (far - near)),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+
+# Affine
+
+def mat4Translate(tx, ty, tz):
+ """4x4 translation matrix."""
+ return numpy.array((
+ (1., 0., 0., tx),
+ (0., 1., 0., ty),
+ (0., 0., 1., tz),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+
+def mat4Scale(sx, sy, sz):
+ """4x4 scale matrix."""
+ return numpy.array((
+ (sx, 0., 0., 0.),
+ (0., sy, 0., 0.),
+ (0., 0., sz, 0.),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+
+def mat4RotateFromAngleAxis(angle, x=0., y=0., z=1.):
+ """4x4 rotation matrix from angle and axis.
+
+ :param float angle: The rotation angle in radians.
+ :param float x: The rotation vector x coordinate.
+ :param float y: The rotation vector y coordinate.
+ :param float z: The rotation vector z coordinate.
+ """
+ ca = numpy.cos(angle)
+ sa = numpy.sin(angle)
+ return numpy.array((
+ ((1.-ca) * x*x + ca, (1.-ca) * x*y - sa*z, (1.-ca) * x*z + sa*y, 0.),
+ ((1.-ca) * x*y + sa*z, (1.-ca) * y*y + ca, (1.-ca) * y*z - sa*x, 0.),
+ ((1.-ca) * x*z - sa*y, (1.-ca) * y*z + sa*x, (1.-ca) * z*z + ca, 0.),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+
+def mat4RotateFromQuaternion(quaternion):
+ """4x4 rotation matrix from quaternion.
+
+ :param quaternion: Array-like unit quaternion stored as (x, y, z, w)
+ """
+ quaternion = numpy.array(quaternion, copy=True)
+ quaternion /= numpy.linalg.norm(quaternion)
+
+ qx, qy, qz, qw = quaternion
+ return numpy.array((
+ (1. - 2.*(qy**2 + qz**2), 2.*(qx*qy - qw*qz), 2.*(qx*qz + qw*qy), 0.),
+ (2.*(qx*qy + qw*qz), 1. - 2.*(qx**2 + qz**2), 2.*(qy*qz - qw*qx), 0.),
+ (2.*(qx*qz - qw*qy), 2.*(qy*qz + qw*qx), 1. - 2.*(qx**2 + qy**2), 0.),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+
+def mat4Shear(axis, sx=0., sy=0., sz=0.):
+ """4x4 shear matrix: Skew two axes relative to a third fixed one.
+
+ shearFactor = tan(shearAngle)
+
+ :param str axis: The axis to keep constant and shear against.
+ In 'x', 'y', 'z'.
+ :param float sx: The shear factor for the X axis relative to axis.
+ :param float sy: The shear factor for the Y axis relative to axis.
+ :param float sz: The shear factor for the Z axis relative to axis.
+ """
+ assert axis in ('x', 'y', 'z')
+
+ matrix = numpy.identity(4, dtype=numpy.float32)
+
+ # Make the shear column
+ index = 'xyz'.find(axis)
+ shearcolumn = numpy.array((sx, sy, sz, 0.), dtype=numpy.float32)
+ shearcolumn[index] = 1.
+ matrix[:, index] = shearcolumn
+ return matrix
+
+
+# Transforms ##################################################################
+
+class Transform(event.Notifier):
+
+ def __init__(self, static=False):
+ """Base class for (row-major) 4x4 matrix transforms.
+
+ :param bool static: False (default) to reset cache when changed,
+ True for static matrices.
+ """
+ super(Transform, self).__init__()
+ self._matrix = None
+ self._inverse = None
+ if not static:
+ self.addListener(self._changed) # Listening self for changes
+
+ def __repr__(self):
+ return '%s(%s)' % (self.__class__.__init__,
+ repr(self.getMatrix(copy=False)))
+
+ def inverse(self):
+ """Return the Transform of the inverse.
+
+ The returned Transform is static, it is not updated when this
+ Transform is modified.
+
+ :return: A Transform which is the inverse of this Transform.
+ """
+ return Inverse(self)
+
+ # Matrix
+
+ def _makeMatrix(self):
+ """Override to build matrix"""
+ return numpy.identity(4, dtype=numpy.float32)
+
+ def _makeInverse(self):
+ """Override to build inverse matrix."""
+ return numpy.linalg.inv(self.getMatrix(copy=False))
+
+ def getMatrix(self, copy=True):
+ """The 4x4 matrix of this transform.
+
+ :param bool copy: True (the default) to get a copy of the matrix,
+ False to get the internal matrix, do not modify!
+ :return: 4x4 matrix of this transform.
+ """
+ if self._matrix is None:
+ self._matrix = self._makeMatrix()
+ if copy:
+ return self._matrix.copy()
+ else:
+ return self._matrix
+
+ matrix = property(getMatrix, doc="The 4x4 matrix of this transform.")
+
+ def getInverseMatrix(self, copy=False):
+ """The 4x4 matrix of the inverse of this transform.
+
+ :param bool copy: True (the default) to get a copy of the matrix,
+ False to get the internal matrix, do not modify!
+ :return: 4x4 matrix of the inverse of this transform.
+ """
+ if self._inverse is None:
+ self._inverse = self._makeInverse()
+ if copy:
+ return self._inverse.copy()
+ else:
+ return self._inverse
+
+ inverseMatrix = property(
+ getInverseMatrix,
+ doc="The 4x4 matrix of the inverse of this transform.")
+
+ # Listener
+
+ def _changed(self, source):
+ """Default self listener reseting matrix cache."""
+ self._matrix = None
+ self._inverse = None
+
+ # Multiplication with vectors
+
+ @staticmethod
+ def _prepareVector(vector, w):
+ """Add 4th coordinate (w) to vector if missing."""
+ assert len(vector) in (3, 4)
+ vector = numpy.array(vector, copy=False, dtype=numpy.float32)
+ if len(vector) == 3:
+ vector = numpy.append(vector, w)
+ return vector
+
+ def transformPoint(self, point, direct=True, perspectiveDivide=False):
+ """Apply the transform to a point.
+
+ If len(point) == 3, apply persective divide if possible.
+
+ :param point: Array-like vector of 3 or 4 coordinates.
+ :param bool direct: Whether to apply the direct (True, the default)
+ or inverse (False) transform.
+ :param bool perspectiveDivide: Whether to apply the perspective divide
+ (True) or not (False, the default).
+ :return: The transformed point.
+ :rtype: numpy.ndarray of same length as point.
+ """
+ if direct:
+ matrix = self.getMatrix(copy=False)
+ else:
+ matrix = self.getInverseMatrix(copy=False)
+ result = numpy.dot(matrix, self._prepareVector(point, 1.))
+
+ if perspectiveDivide and result[3] != 0.:
+ result /= result[3]
+
+ if len(point) == 3:
+ return result[:3]
+ else:
+ return result
+
+ def transformDir(self, direction, direct=True):
+ """Apply the transform to a direction.
+
+ :param direction: Array-like vector of 3 coordinates.
+ :param bool direct: Whether to apply the direct (True, the default)
+ or inverse (False) transform.
+ :return: The transformed direction.
+ :rtype: numpy.ndarray of length 3.
+ """
+ if direct:
+ matrix = self.getMatrix(copy=False)
+ else:
+ matrix = self.getInverseMatrix(copy=False)
+ return numpy.dot(matrix[:3, :3], direction[:3])
+
+ def transformNormal(self, normal, direct=True):
+ """Apply the transform to a normal: R = (M-1)t * V.
+
+ :param normal: Array-like vector of 3 coordinates.
+ :param bool direct: Whether to apply the direct (True, the default)
+ or inverse (False) transform.
+ :return: The transformed normal.
+ :rtype: numpy.ndarray of length 3.
+ """
+ if direct:
+ matrix = self.getInverseMatrix(copy=False).T
+ else:
+ matrix = self.getMatrix(copy=False).T
+ return numpy.dot(matrix[:3, :3], normal[:3])
+
+ _CUBE_CORNERS = numpy.array(list(itertools.product((0., 1.), repeat=3)),
+ dtype=numpy.float32)
+ """Unit cube corners used by :meth:`transformRectangularBox`"""
+
+ def transformBounds(self, bounds, direct=True):
+ """Apply the transform to an axes-aligned rectangular box.
+
+ :param bounds: Min and max coords of the box for each axes.
+ :type bounds: 2x3 numpy.ndarray
+ :param bool direct: Whether to apply the direct (True, the default)
+ or inverse (False) transform.
+ :return: Axes-aligned rectangular box including the transformed box.
+ :rtype: 2x3 numpy.ndarray of float32
+ """
+ corners = numpy.ones((8, 4), dtype=numpy.float32)
+ corners[:, :3] = bounds[0] + \
+ self._CUBE_CORNERS * (bounds[1] - bounds[0])
+
+ if direct:
+ matrix = self.getMatrix(copy=False)
+ else:
+ matrix = self.getInverseMatrix(copy=False)
+
+ # Transform corners
+ cornerstransposed = numpy.dot(matrix, corners.T)
+ cornerstransposed = cornerstransposed / cornerstransposed[3]
+
+ # Get min/max for each axis
+ transformedbounds = numpy.empty((2, 3), dtype=numpy.float32)
+ transformedbounds[0] = cornerstransposed.T[:, :3].min(axis=0)
+ transformedbounds[1] = cornerstransposed.T[:, :3].max(axis=0)
+
+ return transformedbounds
+
+
+class Inverse(Transform):
+ """Transform which is the inverse of another one.
+
+ Static: It never gets updated.
+ """
+
+ def __init__(self, transform):
+ """Initializer.
+
+ :param Transform transform: The transform to invert.
+ """
+
+ super(Inverse, self).__init__(static=True)
+ self._matrix = transform.getInverseMatrix(copy=True)
+ self._inverse = transform.getMatrix(copy=True)
+
+
+class TransformList(Transform, event.HookList):
+ """List of transforms."""
+
+ def __init__(self, iterable=()):
+ Transform.__init__(self)
+ event.HookList.__init__(self, iterable)
+
+ def _listWillChangeHook(self, methodName, *args, **kwargs):
+ for item in self:
+ item.removeListener(self._transformChanged)
+
+ def _listWasChangedHook(self, methodName, *args, **kwargs):
+ for item in self:
+ item.addListener(self._transformChanged)
+ self.notify()
+
+ def _transformChanged(self, source):
+ """Listen to transform changes of the list and its items."""
+ if source is not self: # Avoid infinite recursion
+ self.notify()
+
+ def _makeMatrix(self):
+ matrix = numpy.identity(4, dtype=numpy.float32)
+ for transform in self:
+ matrix = numpy.dot(matrix, transform.getMatrix(copy=False))
+ return matrix
+
+
+class StaticTransformList(Transform):
+ """Transform that is a snapshot of a list of Transforms
+
+ It does not keep reference to the list of Transforms.
+
+ :param iterable: Iterable of Transform used for initialization
+ """
+
+ def __init__(self, iterable=()):
+ super(StaticTransformList, self).__init__(static=True)
+ matrix = numpy.identity(4, dtype=numpy.float32)
+ for transform in iterable:
+ matrix = numpy.dot(matrix, transform.getMatrix(copy=False))
+ self._matrix = matrix # Init matrix once
+
+
+# Affine ######################################################################
+
+class Matrix(Transform):
+
+ def __init__(self, matrix=None):
+ """4x4 Matrix.
+
+ :param matrix: 4x4 array-like matrix or None for identity matrix.
+ """
+ super(Matrix, self).__init__(static=True)
+ self.setMatrix(matrix)
+
+ def setMatrix(self, matrix=None):
+ """Update the 4x4 Matrix.
+
+ :param matrix: 4x4 array-like matrix or None for identity matrix.
+ """
+ if matrix is None:
+ self._matrix = numpy.identity(4, dtype=numpy.float32)
+ else:
+ matrix = numpy.array(matrix, copy=True, dtype=numpy.float32)
+ assert matrix.shape == (4, 4)
+ self._matrix = matrix
+ # Reset cached inverse as Transform is declared static
+ self._inverse = None
+ self.notify()
+
+ # Redefined here to add a setter
+ matrix = property(Transform.getMatrix, setMatrix,
+ doc="The 4x4 matrix of this transform.")
+
+
+class Translate(Transform):
+ """4x4 translation matrix."""
+
+ def __init__(self, tx=0., ty=0., tz=0.):
+ super(Translate, self).__init__()
+ self._tx, self._ty, self._tz = 0., 0., 0.
+ self.setTranslate(tx, ty, tz)
+
+ def _makeMatrix(self):
+ return mat4Translate(self.tx, self.ty, self.tz)
+
+ def _makeInverse(self):
+ return mat4Translate(-self.tx, -self.ty, -self.tz)
+
+ @property
+ def tx(self):
+ return self._tx
+
+ @tx.setter
+ def tx(self, tx):
+ self.setTranslate(tx=tx)
+
+ @property
+ def ty(self):
+ return self._ty
+
+ @ty.setter
+ def ty(self, ty):
+ self.setTranslate(ty=ty)
+
+ @property
+ def tz(self):
+ return self._tz
+
+ @tz.setter
+ def tz(self, tz):
+ self.setTranslate(tz=tz)
+
+ @property
+ def translation(self):
+ return numpy.array((self.tx, self.ty, self.tz), dtype=numpy.float32)
+
+ @translation.setter
+ def translation(self, translations):
+ tx, ty, tz = translations
+ self.setTranslate(tx, ty, tz)
+
+ def setTranslate(self, tx=None, ty=None, tz=None):
+ if tx is not None:
+ self._tx = tx
+ if ty is not None:
+ self._ty = ty
+ if tz is not None:
+ self._tz = tz
+ self.notify()
+
+
+class Scale(Transform):
+ """4x4 scale matrix."""
+
+ def __init__(self, sx=1., sy=1., sz=1.):
+ super(Scale, self).__init__()
+ self._sx, self._sy, self._sz = 0., 0., 0.
+ self.setScale(sx, sy, sz)
+
+ def _makeMatrix(self):
+ return mat4Scale(self.sx, self.sy, self.sz)
+
+ def _makeInverse(self):
+ return mat4Scale(1. / self.sx, 1. / self.sy, 1. / self.sz)
+
+ @property
+ def sx(self):
+ return self._sx
+
+ @sx.setter
+ def sx(self, sx):
+ self.setScale(sx=sx)
+
+ @property
+ def sy(self):
+ return self._sy
+
+ @sy.setter
+ def sy(self, sy):
+ self.setScale(sy=sy)
+
+ @property
+ def sz(self):
+ return self._sz
+
+ @sz.setter
+ def sz(self, sz):
+ self.setScale(sz=sz)
+
+ @property
+ def scale(self):
+ return numpy.array((self._sx, self._sy, self._sz), dtype=numpy.float32)
+
+ @scale.setter
+ def scale(self, scales):
+ sx, sy, sz = scales
+ self.setScale(sx, sy, sz)
+
+ def setScale(self, sx=None, sy=None, sz=None):
+ if sx is not None:
+ assert sx != 0.
+ self._sx = sx
+ if sy is not None:
+ assert sy != 0.
+ self._sy = sy
+ if sz is not None:
+ assert sz != 0.
+ self._sz = sz
+ self.notify()
+
+
+class Rotate(Transform):
+
+ def __init__(self, angle=0., ax=0., ay=0., az=1.):
+ """4x4 rotation matrix.
+
+ :param float angle: The rotation angle in degrees.
+ :param float ax: The x coordinate of the rotation axis.
+ :param float ay: The y coordinate of the rotation axis.
+ :param float az: The z coordinate of the rotation axis.
+ """
+ super(Rotate, self).__init__()
+ self._angle = 0.
+ self._axis = None
+ self.setAngleAxis(angle, (ax, ay, az))
+
+ @property
+ def angle(self):
+ """The rotation angle in degrees."""
+ return self._angle
+
+ @angle.setter
+ def angle(self, angle):
+ self.setAngleAxis(angle=angle)
+
+ @property
+ def axis(self):
+ """The normalized rotation axis as a numpy.ndarray."""
+ return self._axis.copy()
+
+ @axis.setter
+ def axis(self, axis):
+ self.setAngleAxis(axis=axis)
+
+ def setAngleAxis(self, angle=None, axis=None):
+ """Update the angle and/or axis of the rotation.
+
+ :param float angle: The rotation angle in degrees.
+ :param axis: Array-like axis vector (3 coordinates).
+ """
+ if angle is not None:
+ self._angle = angle
+ if axis is not None:
+ assert len(axis) == 3
+ axis = numpy.array(axis, copy=True, dtype=numpy.float32)
+ assert axis.size == 3
+ norm = numpy.linalg.norm(axis)
+ if norm == 0.: # No axis, set rotation angle to 0.
+ self._angle = 0.
+ self._axis = numpy.array((0., 0., 1.), dtype=numpy.float32)
+ else:
+ self._axis = axis / norm
+
+ if angle is not None or axis is not None:
+ self.notify()
+
+ @property
+ def quaternion(self):
+ """Rotation unit quaternion as (x, y, z, w).
+
+ Where: ||(x, y, z)|| = sin(angle/2), w = cos(angle/2).
+ """
+ if numpy.linalg.norm(self._axis) == 0.:
+ return numpy.array((0., 0., 0., 1.), dtype=numpy.float32)
+
+ else:
+ quaternion = numpy.empty((4,), dtype=numpy.float32)
+ halfangle = 0.5 * numpy.radians(self.angle)
+ quaternion[0:3] = numpy.sin(halfangle) * self._axis
+ quaternion[3] = numpy.cos(halfangle)
+ return quaternion
+
+ @quaternion.setter
+ def quaternion(self, quaternion):
+ assert len(quaternion) == 4
+
+ # Normalize quaternion
+ quaternion = numpy.array(quaternion, copy=True)
+ quaternion /= numpy.linalg.norm(quaternion)
+
+ # Get angle
+ sinhalfangle = numpy.linalg.norm(quaternion[0:3])
+ coshalfangle = quaternion[3]
+ angle = 2. * numpy.arctan2(sinhalfangle, coshalfangle)
+
+ # Axis will be normalized in setAngleAxis
+ self.setAngleAxis(numpy.degrees(angle), quaternion[0:3])
+
+ def _makeMatrix(self):
+ angle = numpy.radians(self.angle, dtype=numpy.float32)
+ return mat4RotateFromAngleAxis(angle, *self.axis)
+
+ def _makeInverse(self):
+ return numpy.array(self.getMatrix(copy=False).transpose(),
+ copy=True, order='C',
+ dtype=numpy.float32)
+
+
+class Shear(Transform):
+
+ def __init__(self, axis, sx=0., sy=0., sz=0.):
+ """4x4 shear/skew matrix of 2 axes relative to the third one.
+
+ :param str axis: The axis to keep fixed, in 'x', 'y', 'z'
+ :param float sx: The shear factor for the x axis.
+ :param float sy: The shear factor for the y axis.
+ :param float sz: The shear factor for the z axis.
+ """
+ assert axis in ('x', 'y', 'z')
+ super(Shear, self).__init__()
+ self._axis = axis
+ self._factors = sx, sy, sz
+
+ @property
+ def axis(self):
+ """The axis against which other axes are skewed."""
+ return self._axis
+
+ @property
+ def factors(self):
+ """The shear factors: shearFactor = tan(shearAngle)"""
+ return self._factors
+
+ def _makeMatrix(self):
+ return mat4Shear(self.axis, *self.factors)
+
+ def _makeInverse(self):
+ sx, sy, sz = self.factors
+ return mat4Shear(self.axis, -sx, -sy, -sz)
+
+
+# Projection ##################################################################
+
+class _Projection(Transform):
+ """Base class for projection matrix.
+
+ Handles near and far clipping plane values.
+ Subclasses must implement :meth:`_makeMatrix`.
+
+ :param float near: Distance to the near plane.
+ :param float far: Distance to the far plane.
+ :param bool checkDepthExtent: Toggle checks near > 0 and far > near.
+ :param size: Viewport's size used to compute the aspect ratio.
+ :type size: 2-tuple of float (width, height).
+ """
+
+ def __init__(self, near, far, checkDepthExtent=False, size=(1., 1.)):
+ super(_Projection, self).__init__()
+ self._checkDepthExtent = checkDepthExtent
+ self._depthExtent = 1, 10
+ self.setDepthExtent(near, far) # set _depthExtent
+ self._size = 1., 1.
+ self.size = size # set _size
+
+ def setDepthExtent(self, near=None, far=None):
+ """Set the extent of the visible area along the viewing direction.
+
+ :param float near: The near clipping plane Z coord.
+ :param float far: The far clipping plane Z coord.
+ """
+ near = float(near) if near is not None else self._depthExtent[0]
+ far = float(far) if far is not None else self._depthExtent[1]
+
+ if self._checkDepthExtent:
+ assert near > 0.
+ assert far > near
+
+ self._depthExtent = near, far
+ self.notify()
+
+ @property
+ def near(self):
+ """Distance to the near plane."""
+ return self._depthExtent[0]
+
+ @near.setter
+ def near(self, near):
+ if near != self.near:
+ self.setDepthExtent(near=near)
+
+ @property
+ def far(self):
+ """Distance to the far plane."""
+ return self._depthExtent[1]
+
+ @far.setter
+ def far(self, far):
+ if far != self.far:
+ self.setDepthExtent(far=far)
+
+ @property
+ def size(self):
+ """Viewport size as a 2-tuple of float (width, height)."""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ assert len(size) == 2
+ self._size = tuple(size)
+ self.notify()
+
+
+class Orthographic(_Projection):
+ """Orthographic (i.e., parallel) projection which keeps aspect ratio.
+
+ Clipping planes are adjusted to match the aspect ratio of
+ the :attr:`size` attribute.
+
+ The left, right, bottom and top parameters defines the area which must
+ always remain visible.
+ Effective clipping planes are adjusted to keep the aspect ratio.
+
+ :param float left: Coord of the left clipping plane.
+ :param float right: Coord of the right clipping plane.
+ :param float bottom: Coord of the bottom clipping plane.
+ :param float top: Coord of the top clipping plane.
+ :param float near: Distance to the near plane.
+ :param float far: Distance to the far plane.
+ :param size: Viewport's size used to compute the aspect ratio.
+ :type size: 2-tuple of float (width, height).
+ """
+
+ def __init__(self, left=0., right=1., bottom=1., top=0., near=-1., far=1.,
+ size=(1., 1.)):
+ self._left, self._right = left, right
+ self._bottom, self._top = bottom, top
+ super(Orthographic, self).__init__(near, far, checkDepthExtent=False,
+ size=size)
+ # _update called when setting size
+
+ def _makeMatrix(self):
+ return mat4Orthographic(
+ self.left, self.right, self.bottom, self.top, self.near, self.far)
+
+ def _update(self, left, right, bottom, top):
+ width, height = self.size
+ aspect = width / height
+
+ orthoaspect = abs(left - right) / abs(bottom - top)
+
+ if orthoaspect >= aspect: # Keep width, enlarge height
+ newheight = \
+ numpy.sign(top - bottom) * abs(left - right) / aspect
+ bottom = 0.5 * (bottom + top) - 0.5 * newheight
+ top = bottom + newheight
+
+ else: # Keep height, enlarge width
+ newwidth = \
+ numpy.sign(right - left) * abs(bottom - top) * aspect
+ left = 0.5 * (left + right) - 0.5 * newwidth
+ right = left + newwidth
+
+ # Store values
+ self._left, self._right = left, right
+ self._bottom, self._top = bottom, top
+
+ def setClipping(self, left=None, right=None, bottom=None, top=None):
+ """Set the clipping planes of the projection.
+
+ Parameters are adjusted to keep aspect ratio.
+ If a clipping plane coord is not provided, it uses its current value
+
+ :param float left: Coord of the left clipping plane.
+ :param float right: Coord of the right clipping plane.
+ :param float bottom: Coord of the bottom clipping plane.
+ :param float top: Coord of the top clipping plane.
+ """
+ left = float(left) if left is not None else self.left
+ right = float(right) if right is not None else self.right
+ bottom = float(bottom) if bottom is not None else self.bottom
+ top = float(top) if top is not None else self.top
+
+ self._update(left, right, bottom, top)
+ self.notify()
+
+ left = property(lambda self: self._left,
+ doc="Coord of the left clipping plane.")
+
+ right = property(lambda self: self._right,
+ doc="Coord of the right clipping plane.")
+
+ bottom = property(lambda self: self._bottom,
+ doc="Coord of the bottom clipping plane.")
+
+ top = property(lambda self: self._top,
+ doc="Coord of the top clipping plane.")
+
+ @property
+ def size(self):
+ """Viewport size as a 2-tuple of float (width, height) or None."""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ assert len(size) == 2
+ self._size = float(size[0]), float(size[1])
+ self._update(self.left, self.right, self.bottom, self.top)
+ self.notify()
+
+
+class Ortho2DWidget(_Projection):
+ """Orthographic projection with pixel as unit.
+
+ Provides same coordinates as widgets:
+ origin: top left, X axis goes left, Y axis goes down.
+
+ :param float near: Z coordinate of the near clipping plane.
+ :param float far: Z coordinante of the far clipping plane.
+ :param size: Viewport's size used to compute the aspect ratio.
+ :type size: 2-tuple of float (width, height).
+ """
+
+ def __init__(self, near=-1., far=1., size=(1., 1.)):
+
+ super(Ortho2DWidget, self).__init__(near, far, size)
+
+ def _makeMatrix(self):
+ width, height = self.size
+ return mat4Orthographic(0., width, height, 0., self.near, self.far)
+
+
+class Perspective(_Projection):
+ """Perspective projection matrix defined by FOV and aspect ratio.
+
+ :param float fovy: Vertical field-of-view in degrees.
+ :param float near: The near clipping plane Z coord (stricly positive).
+ :param float far: The far clipping plane Z coord (> near).
+ :param size: Viewport's size used to compute the aspect ratio.
+ :type size: 2-tuple of float (width, height).
+ """
+
+ def __init__(self, fovy=90., near=0.1, far=1., size=(1., 1.)):
+
+ super(Perspective, self).__init__(near, far, checkDepthExtent=True)
+ self._fovy = 90.
+ self.fovy = fovy # Set _fovy
+ self.size = size # Set _ size
+
+ def _makeMatrix(self):
+ width, height = self.size
+ return mat4Perspective(self.fovy, width, height, self.near, self.far)
+
+ @property
+ def fovy(self):
+ """Vertical field-of-view in degrees."""
+ return self._fovy
+
+ @fovy.setter
+ def fovy(self, fovy):
+ self._fovy = float(fovy)
+ self.notify()
diff --git a/silx/gui/plot3d/scene/utils.py b/silx/gui/plot3d/scene/utils.py
new file mode 100644
index 0000000..930a087
--- /dev/null
+++ b/silx/gui/plot3d/scene/utils.py
@@ -0,0 +1,516 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module provides functions to generate indices, to check intersection
+and to handle planes.
+"""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import logging
+import numpy
+
+from . import event
+
+
+_logger = logging.getLogger(__name__)
+
+
+# numpy #######################################################################
+
+def _uniqueAlongLastAxis(a):
+ """Numpy unique on the last axis of a 2D array
+
+ Implemented here as not in numpy as of writing.
+
+ See adding axis parameter to numpy.unique:
+ https://github.com/numpy/numpy/pull/3584/files#r6225452
+
+ :param array_like a: Input array.
+ :return: Unique elements along the last axis.
+ :rtype: numpy.ndarray
+ """
+ assert len(a.shape) == 2
+
+ # Construct a type over last array dimension to run unique on a 1D array
+ if a.dtype.char in numpy.typecodes['AllInteger']:
+ # Bit-wise comparison of the 2 indices of a line at once
+ # Expect a C contiguous array of shape N, 2
+ uniquedt = numpy.dtype((numpy.void, a.itemsize * a.shape[-1]))
+ elif a.dtype.char in numpy.typecodes['Float']:
+ uniquedt = [('f{i}'.format(i=i), a.dtype) for i in range(a.shape[-1])]
+ else:
+ raise TypeError("Unsupported type {dtype}".format(dtype=a.dtype))
+
+ uniquearray = numpy.unique(numpy.ascontiguousarray(a).view(uniquedt))
+ return uniquearray.view(a.dtype).reshape((-1, a.shape[-1]))
+
+
+# conversions #################################################################
+
+def triangleToLineIndices(triangleIndices, unicity=False):
+ """Generates lines indices from triangle indices.
+
+ This is generating lines indices for the edges of the triangles.
+
+ :param triangleIndices: The indices to draw a set of vertices as triangles.
+ :type triangleIndices: numpy.ndarray
+ :param bool unicity: If True remove duplicated lines,
+ else (the default) returns all lines.
+ :return: The indices to draw the edges of the triangles as lines.
+ :rtype: 1D numpy.ndarray of uint16 or uint32.
+ """
+ # Makes sure indices ar packed by triangle
+ triangleIndices = triangleIndices.reshape(-1, 3)
+
+ # Pack line indices by triangle and by edge
+ lineindices = numpy.empty((len(triangleIndices), 3, 2),
+ dtype=triangleIndices.dtype)
+ lineindices[:, 0] = triangleIndices[:, :2] # edge = t0, t1
+ lineindices[:, 1] = triangleIndices[:, 1:] # edge =t1, t2
+ lineindices[:, 2] = triangleIndices[:, ::2] # edge = t0, t2
+
+ if unicity:
+ lineindices = _uniqueAlongLastAxis(lineindices.reshape(-1, 2))
+
+ # Make sure it is 1D
+ lineindices.shape = -1
+
+ return lineindices
+
+
+def verticesNormalsToLines(vertices, normals, scale=1.):
+ """Return vertices of lines representing normals at given positions.
+
+ :param vertices: Positions of the points.
+ :type vertices: numpy.ndarray with shape: (nbPoints, 3)
+ :param normals: Corresponding normals at the points.
+ :type normals: numpy.ndarray with shape: (nbPoints, 3)
+ :param float scale: The scale factor to apply to normals.
+ :returns: Array of vertices to draw corresponding lines.
+ :rtype: numpy.ndarray with shape: (nbPoints * 2, 3)
+ """
+ linevertices = numpy.empty((len(vertices) * 2, 3), dtype=vertices.dtype)
+ linevertices[0::2] = vertices
+ linevertices[1::2] = vertices + scale * normals
+ return linevertices
+
+
+def unindexArrays(mode, indices, *arrays):
+ """Convert indexed GL primitives to unindexed ones.
+
+ Given indices in arrays and the OpenGL primitive they represent,
+ return the unindexed equivalent.
+
+ :param str mode:
+ Kind of primitive represented by indices.
+ In: points, lines, line_strip, loop, triangles, triangle_strip, fan.
+ :param indices: Indices in other arrays
+ :type indices: numpy.ndarray of dimension 1.
+ :param arrays: Remaining arguments are arrays to convert
+ :return: Converted arrays
+ :rtype: tuple of numpy.ndarray
+ """
+ indices = numpy.array(indices, copy=False)
+
+ assert mode in ('points',
+ 'lines', 'line_strip', 'loop',
+ 'triangles', 'triangle_strip', 'fan')
+
+ if mode in ('lines', 'line_strip', 'loop'):
+ assert len(indices) >= 2
+ elif mode in ('triangles', 'triangle_strip', 'fan'):
+ assert len(indices) >= 3
+
+ assert indices.min() >= 0
+ max_index = indices.max()
+ for data in arrays:
+ assert len(data) >= max_index
+
+ if mode == 'line_strip':
+ unpacked = numpy.empty((2 * (len(indices) - 1),), dtype=indices.dtype)
+ unpacked[0::2] = indices[:-1]
+ unpacked[1::2] = indices[1:]
+ indices = unpacked
+
+ elif mode == 'loop':
+ unpacked = numpy.empty((2 * len(indices),), dtype=indices.dtype)
+ unpacked[0::2] = indices
+ unpacked[1:-1:2] = indices[1:]
+ unpacked[-1] = indices[0]
+ indices = unpacked
+
+ elif mode == 'triangle_strip':
+ unpacked = numpy.empty((3 * (len(indices) - 2),), dtype=indices.dtype)
+ unpacked[0::3] = indices[:-2]
+ unpacked[1::3] = indices[1:-1]
+ unpacked[2::3] = indices[2:]
+ indices = unpacked
+
+ elif mode == 'fan':
+ unpacked = numpy.empty((3 * (len(indices) - 2),), dtype=indices.dtype)
+ unpacked[0::3] = indices[0]
+ unpacked[1::3] = indices[1:-1]
+ unpacked[2::3] = indices[2:]
+ indices = unpacked
+
+ return tuple(numpy.ascontiguousarray(data[indices]) for data in arrays)
+
+
+def trianglesNormal(positions):
+ """Return normal for each triangle.
+
+ :param positions: Serie of triangle's corners
+ :type positions: numpy.ndarray of shape (NbTriangles*3, 3)
+ :return: Normals corresponding to each position.
+ :rtype: numpy.ndarray of shape (NbTriangles, 3)
+ """
+ assert positions.ndim == 2
+ assert positions.shape[1] == 3
+
+ positions = numpy.array(positions, copy=False).reshape(-1, 3, 3)
+
+ normals = numpy.cross(positions[:, 1] - positions[:, 0],
+ positions[:, 2] - positions[:, 0])
+
+ # Normalize normals
+ if numpy.version.version < '1.8.0':
+ # Debian 7 support: numpy.linalg.norm has no axis argument
+ norms = numpy.array(tuple(numpy.linalg.norm(vec) for vec in normals),
+ dtype=normals.dtype)
+ else:
+ norms = numpy.linalg.norm(normals, axis=1)
+ norms[norms == 0] = 1
+
+ return normals / norms.reshape(-1, 1)
+
+
+# grid ########################################################################
+
+def gridVertices(dim0Array, dim1Array, dtype):
+ """Generate an array of 2D positions from 2 arrays of 1D coordinates.
+
+ :param dim0Array: 1D array-like of coordinates along the first dimension.
+ :param dim1Array: 1D array-like of coordinates along the second dimension.
+ :param numpy.dtype dtype: Data type of the output array.
+ :return: Array of grid coordinates.
+ :rtype: numpy.ndarray with shape: (len(dim0Array), len(dim1Array), 2)
+ """
+ grid = numpy.empty((len(dim0Array), len(dim1Array), 2), dtype=dtype)
+ grid.T[0, :, :] = dim0Array
+ grid.T[1, :, :] = numpy.array(dim1Array, copy=False)[:, None]
+ return grid
+
+
+def triangleStripGridIndices(dim0, dim1):
+ """Generate indices to draw a grid of vertices as a triangle strip.
+
+ Vertices are expected to be stored as row-major (i.e., C contiguous).
+
+ :param int dim0: The number of rows of vertices.
+ :param int dim1: The number of columns of vertices.
+ :return: The vertex indices
+ :rtype: 1D numpy.ndarray of uint32
+ """
+ assert dim0 >= 2
+ assert dim1 >= 2
+
+ # Filling a row of squares +
+ # an index before and one after for degenerated triangles
+ indices = numpy.empty((dim0 - 1, 2 * (dim1 + 1)), dtype=numpy.uint32)
+
+ # Init indices with minimum indices for each row of squares
+ indices[:] = (dim1 * numpy.arange(dim0 - 1, dtype=numpy.uint32))[:, None]
+
+ # Update indices with offset per row of squares
+ offset = numpy.arange(dim1, dtype=numpy.uint32)
+ indices[:, 1:-1:2] += offset
+ offset += dim1
+ indices[:, 2::2] += offset
+ indices[:, -1] += offset[-1]
+
+ # Remove extra indices for degenerated triangles before returning
+ return indices.ravel()[1:-1]
+
+ # Alternative:
+ # indices = numpy.zeros(2 * dim1 * (dim0 - 1) + 2 * (dim0 - 2),
+ # dtype=numpy.uint32)
+ #
+ # offset = numpy.arange(dim1, dtype=numpy.uint32)
+ # for d0Index in range(dim0 - 1):
+ # start = 2 * d0Index * (dim1 + 1)
+ # end = start + 2 * dim1
+ # if d0Index != 0:
+ # indices[start - 2] = offset[-1]
+ # indices[start - 1] = offset[0]
+ # indices[start:end:2] = offset
+ # offset += dim1
+ # indices[start + 1:end:2] = offset
+ # return indices
+
+
+def linesGridIndices(dim0, dim1):
+ """Generate indices to draw a grid of vertices as lines.
+
+ Vertices are expected to be stored as row-major (i.e., C contiguous).
+
+ :param int dim0: The number of rows of vertices.
+ :param int dim1: The number of columns of vertices.
+ :return: The vertex indices.
+ :rtype: 1D numpy.ndarray of uint32
+ """
+ # Horizontal and vertical lines
+ nbsegmentalongdim1 = 2 * (dim1 - 1)
+ nbsegmentalongdim0 = 2 * (dim0 - 1)
+
+ indices = numpy.empty(nbsegmentalongdim1 * dim0 +
+ nbsegmentalongdim0 * dim1,
+ dtype=numpy.uint32)
+
+ # Line indices over dim0
+ onedim1line = (numpy.arange(nbsegmentalongdim1,
+ dtype=numpy.uint32) + 1) // 2
+ indices[:dim0 * nbsegmentalongdim1] = \
+ (dim1 * numpy.arange(dim0, dtype=numpy.uint32)[:, None] +
+ onedim1line[None, :]).ravel()
+
+ # Line indices over dim1
+ onedim0line = (numpy.arange(nbsegmentalongdim0,
+ dtype=numpy.uint32) + 1) // 2
+ indices[dim0 * nbsegmentalongdim1:] = \
+ (numpy.arange(dim1, dtype=numpy.uint32)[:, None] +
+ dim1 * onedim0line[None, :]).ravel()
+
+ return indices
+
+
+# intersection ################################################################
+
+def angleBetweenVectors(refVector, vectors, norm=None):
+ """Return the angle between 2 vectors.
+
+ :param refVector: Coordinates of the reference vector.
+ :type refVector: numpy.ndarray of shape: (NCoords,)
+ :param vectors: Coordinates of the vector(s) to get angle from reference.
+ :type vectors: numpy.ndarray of shape: (NCoords,) or (NbVector, NCoords)
+ :param norm: A direction vector giving an orientation to the angles
+ or None.
+ :returns: The angles in radians in [0, pi] if norm is None
+ else in [0, 2pi].
+ :rtype: float or numpy.ndarray of shape (NbVectors,)
+ """
+ singlevector = len(vectors.shape) == 1
+ if singlevector: # Make it a 2D array for the computation
+ vectors = vectors.reshape(1, -1)
+
+ assert len(refVector.shape) == 1
+ assert len(vectors.shape) == 2
+ assert len(refVector) == vectors.shape[1]
+
+ # Normalize vectors
+ refVector /= numpy.linalg.norm(refVector)
+ vectors = numpy.array([v / numpy.linalg.norm(v) for v in vectors])
+
+ dots = numpy.sum(refVector * vectors, axis=-1)
+ angles = numpy.arccos(numpy.clip(dots, -1., 1.))
+ if norm is not None:
+ signs = numpy.sum(norm * numpy.cross(refVector, vectors), axis=-1) < 0.
+ angles[signs] = numpy.pi * 2. - angles[signs]
+
+ return angles[0] if singlevector else angles
+
+
+def segmentPlaneIntersect(s0, s1, planeNorm, planePt):
+ """Compute the intersection of a segment with a plane.
+
+ :param s0: First end of the segment
+ :type s0: 1D numpy.ndarray-like of length 3
+ :param s1: Second end of the segment
+ :type s1: 1D numpy.ndarray-like of length 3
+ :param planeNorm: Normal vector of the plane.
+ :type planeNorm: numpy.ndarray of shape: (3,)
+ :param planePt: A point of the plane.
+ :type planePt: numpy.ndarray of shape: (3,)
+ :return: The intersection points. The number of points goes
+ from 0 (no intersection) to 2 (segment in the plane)
+ :rtype: list of numpy.ndarray
+ """
+ s0, s1 = numpy.asarray(s0), numpy.asarray(s1)
+
+ segdir = s1 - s0
+ dotnormseg = numpy.dot(planeNorm, segdir)
+ if dotnormseg == 0:
+ # line and plane are parallels
+ if numpy.dot(planeNorm, planePt - s0) == 0: # segment is in plane
+ return [s0, s1]
+ else: # No intersection
+ return []
+
+ alpha = - numpy.dot(planeNorm, s0 - planePt) / dotnormseg
+ if 0. <= alpha <= 1.: # Intersection with segment
+ return [s0 + alpha * segdir]
+ else: # intersection outside segment
+ return []
+
+
+def boxPlaneIntersect(boxVertices, boxLineIndices, planeNorm, planePt):
+ """Return intersection points between a box and a plane.
+
+ :param boxVertices: Position of the corners of the box.
+ :type boxVertices: numpy.ndarray with shape: (8, 3)
+ :param boxLineIndices: Indices of the box edges.
+ :type boxLineIndices: numpy.ndarray-like with shape: (12, 2)
+ :param planeNorm: Normal vector of the plane.
+ :type planeNorm: numpy.ndarray of shape: (3,)
+ :param planePt: A point of the plane.
+ :type planePt: numpy.ndarray of shape: (3,)
+ :return: The found intersection points
+ :rtype: numpy.ndarray with 2 dimensions
+ """
+ segments = numpy.take(boxVertices, boxLineIndices, axis=0)
+
+ points = set() # Gather unique intersection points
+ for seg in segments:
+ for point in segmentPlaneIntersect(seg[0], seg[1], planeNorm, planePt):
+ points.add(tuple(point))
+ points = numpy.array(list(points))
+
+ if len(points) <= 2:
+ return numpy.array(())
+ elif len(points) == 3:
+ return points
+ else: # len(points) > 3
+ # Order point to have a polyline lying on the unit cube's faces
+ vectors = points - numpy.mean(points, axis=0)
+ angles = angleBetweenVectors(vectors[0], vectors, planeNorm)
+ points = numpy.take(points, numpy.argsort(angles), axis=0)
+ return points
+
+
+# Plane #######################################################################
+
+class Plane(event.Notifier):
+ """Object handling a plane and notifying plane changes.
+
+ :param point: A point on the plane.
+ :type point: 3-tuple of float.
+ :param normal: Normal of the plane.
+ :type normal: 3-tuple of float.
+ """
+
+ def __init__(self, point=(0., 0., 0.), normal=(0., 0., 1.)):
+ super(Plane, self).__init__()
+
+ assert len(point) == 3
+ self._point = numpy.array(point, copy=True, dtype=numpy.float32)
+ assert len(normal) == 3
+ self._normal = numpy.array(normal, copy=True, dtype=numpy.float32)
+ self.notify()
+
+ def setPlane(self, point=None, normal=None):
+ """Set plane point and normal and notify.
+
+ :param point: A point on the plane.
+ :type point: 3-tuple of float or None.
+ :param normal: Normal of the plane.
+ :type normal: 3-tuple of float or None.
+ """
+ planechanged = False
+
+ if point is not None:
+ assert len(point) == 3
+ point = numpy.array(point, copy=True, dtype=numpy.float32)
+ if not numpy.all(numpy.equal(self._point, point)):
+ self._point = point
+ planechanged = True
+
+ if normal is not None:
+ assert len(normal) == 3
+ normal = numpy.array(normal, copy=True, dtype=numpy.float32)
+
+ norm = numpy.linalg.norm(normal)
+ if norm != 0.:
+ normal /= norm
+
+ if not numpy.all(numpy.equal(self._normal, normal)):
+ self._normal = normal
+ planechanged = True
+
+ if planechanged:
+ _logger.debug('Plane updated:\n\tpoint: %s\n\tnormal: %s',
+ str(self._point), str(self._normal))
+ self.notify()
+
+ @property
+ def point(self):
+ """A point on the plane."""
+ return self._point.copy()
+
+ @point.setter
+ def point(self, point):
+ self.setPlane(point=point)
+
+ @property
+ def normal(self):
+ """The (normalized) normal of the plane."""
+ return self._normal.copy()
+
+ @normal.setter
+ def normal(self, normal):
+ self.setPlane(normal=normal)
+
+ @property
+ def parameters(self):
+ """Plane equation parameters: a*x + b*y + c*z + d = 0."""
+ return numpy.append(self._normal,
+ - numpy.dot(self._point, self._normal))
+
+ @parameters.setter
+ def parameters(self, parameters):
+ assert len(parameters) == 4
+ parameters = numpy.array(parameters, dtype=numpy.float32)
+
+ # Normalize normal
+ norm = numpy.linalg.norm(parameters[:3])
+ if norm != 0:
+ parameters /= norm
+
+ normal = parameters[:3]
+ point = - parameters[3] * normal
+ self.setPlane(point, normal)
+
+ @property
+ def isPlane(self):
+ """True if a plane is defined (i.e., ||normal|| != 0)."""
+ return numpy.any(self.normal != 0.)
+
+ def move(self, step):
+ """Move the plane of step along the normal."""
+ self.point += step * self.normal
diff --git a/silx/gui/plot3d/scene/viewport.py b/silx/gui/plot3d/scene/viewport.py
new file mode 100644
index 0000000..83cda43
--- /dev/null
+++ b/silx/gui/plot3d/scene/viewport.py
@@ -0,0 +1,492 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a class to control a viewport on the rendering window.
+
+The :class:`Viewport` describes a Viewport rendering a scene.
+The attribute :attr:`scene` is the root group of the scene tree.
+:class:`RenderContext` handles the current state during rendering.
+"""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import numpy
+
+from silx.gui.plot.Colors import rgba
+
+from ..._glutils import gl
+
+from . import camera
+from . import event
+from . import transform
+from .function import DirectionalLight, ClippingPlane
+
+
+class RenderContext(object):
+ """Handle a current rendering context.
+
+ An instance of this class is passed to rendering method through
+ the scene during render.
+
+ User should NEVER use an instance of this class beyond the method
+ it is passed to as an argument (i.e., do not keep a reference to it).
+
+ :param Viewport viewport: The viewport doing the rendering.
+ :param Context glContext: The operating system OpenGL context in use.
+ """
+
+ def __init__(self, viewport, glContext):
+ self._viewport = viewport
+ self._glContext = glContext
+ self._transformStack = [viewport.camera.extrinsic]
+ self._clipPlane = ClippingPlane(normal=(0., 0., 0.))
+
+ @property
+ def viewport(self):
+ """Viewport doing the current rendering"""
+ return self._viewport
+
+ @property
+ def glCtx(self):
+ """The OpenGL context in use"""
+ return self._glContext
+
+ @property
+ def objectToCamera(self):
+ """The current transform from object to camera coords.
+
+ Do not modify.
+ """
+ return self._transformStack[-1]
+
+ @property
+ def projection(self):
+ """Projection transform.
+
+ Do not modify.
+ """
+ return self.viewport.camera.intrinsic
+
+ @property
+ def objectToNDC(self):
+ """The transform from object to NDC (this includes projection).
+
+ Do not modify.
+ """
+ return transform.StaticTransformList(
+ (self.projection, self.objectToCamera))
+
+ def pushTransform(self, transform_, multiply=True):
+ """Push a :class:`Transform` on the transform stack.
+
+ :param Transform transform_: The transform to add to the stack.
+ :param bool multiply:
+ True (the default) to multiply with the top of the stack,
+ False to push the transform as is without multiplication.
+ """
+ if multiply:
+ assert len(self._transformStack) >= 1
+ transform_ = transform.StaticTransformList(
+ (self._transformStack[-1], transform_))
+
+ self._transformStack.append(transform_)
+
+ def popTransform(self):
+ """Pop the transform on top of the stack.
+
+ :return: The Transform that is popped from the stack.
+ """
+ assert len(self._transformStack) > 1
+ return self._transformStack.pop()
+
+ @property
+ def clipper(self):
+ """The current clipping plane
+ """
+ return self._clipPlane
+
+ def setClipPlane(self, point=(0., 0., 0.), normal=(0., 0., 0.)):
+ """Set the clipping plane to use
+
+ For now only handles a single clipping plane.
+
+ :param point: A point of the plane
+ :type point: 3-tuple of float
+ :param normal: Normal vector of the plane or (0, 0, 0) for no clipping
+ :type normal: 3-tuple of float
+ """
+ self._clipPlane = ClippingPlane(point, normal)
+
+
+class Viewport(event.Notifier):
+ """Rendering a single scene through a camera in part of a framebuffer.
+
+ :param int framebuffer: The framebuffer ID this viewport is rendering into
+ """
+
+ def __init__(self, framebuffer=0):
+ from . import Group # Here to avoid cyclic import
+ super(Viewport, self).__init__()
+ self._dirty = True
+ self._origin = 0, 0
+ self._size = 1, 1
+ self._framebuffer = int(framebuffer)
+ self.scene = Group() # The stuff to render, add overlaid scenes?
+ self.scene._setParent(self)
+ self.scene.addListener(self._changed)
+ self._background = 0., 0., 0., 1.
+ self._camera = camera.Camera(fovy=30., near=1., far=100.,
+ position=(0., 0., 12.))
+ self._camera.addListener(self._changed)
+ self._transforms = transform.TransformList([self._camera])
+
+ self._light = DirectionalLight(direction=(0., 0., -1.),
+ ambient=(0.3, 0.3, 0.3),
+ diffuse=(0.7, 0.7, 0.7))
+ self._light.addListener(self._changed)
+
+ @property
+ def transforms(self):
+ """Proxy of camera transforms.
+
+ Do not modify the list.
+ """
+ return self._transforms
+
+ def _changed(self, *args, **kwargs):
+ """Callback handling scene updates"""
+ self._dirty = True
+ self.notify()
+
+ @property
+ def dirty(self):
+ """True if scene is dirty and needs redisplay."""
+ return self._dirty
+
+ def resetDirty(self):
+ """Mark the scene as not being dirty.
+
+ To call after rendering.
+ """
+ self._dirty = False
+
+ @property
+ def background(self):
+ """Background color of the viewport (4-tuple of float in [0, 1]"""
+ return self._background
+
+ @background.setter
+ def background(self, color):
+ color = rgba(color)
+ if self._background != color:
+ self._background = color
+ self._changed()
+
+ @property
+ def camera(self):
+ """The camera used to render the scene."""
+ return self._camera
+
+ @property
+ def light(self):
+ """The light used to render the scene."""
+ return self._light
+
+ @property
+ def origin(self):
+ """Origin (ox, oy) of the viewport in pixels"""
+ return self._origin
+
+ @origin.setter
+ def origin(self, origin):
+ ox, oy = origin
+ origin = int(ox), int(oy)
+ if origin != self._origin:
+ self._origin = origin
+ self._changed()
+
+ @property
+ def size(self):
+ """Size (width, height) of the viewport in pixels"""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ w, h = size
+ size = int(w), int(h)
+ if size != self._size:
+ self._size = size
+
+ self.camera.intrinsic.size = size
+ self._changed()
+
+ @property
+ def shape(self):
+ """Shape (height, width) of the viewport in pixels.
+
+ This is a convenient wrapper to the inverse of size.
+ """
+ return self._size[1], self._size[0]
+
+ @shape.setter
+ def shape(self, shape):
+ self.size = shape[1], shape[0]
+
+ @property
+ def framebuffer(self):
+ """The framebuffer ID this viewport is rendering into (int)"""
+ return self._framebuffer
+
+ @framebuffer.setter
+ def framebuffer(self, framebuffer):
+ self._framebuffer = int(framebuffer)
+
+ def render(self, glContext):
+ """Perform the rendering of the viewport
+
+ :param Context glContext: The context used for rendering"""
+ # Get a chance to run deferred delete
+ glContext.cleanGLGarbage()
+
+ # OpenGL set-up: really need to be done once
+ ox, oy = self.origin
+ w, h = self.size
+ gl.glViewport(ox, oy, w, h)
+
+ gl.glEnable(gl.GL_SCISSOR_TEST)
+ gl.glScissor(ox, oy, w, h)
+
+ gl.glEnable(gl.GL_BLEND)
+ gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)
+
+ gl.glEnable(gl.GL_DEPTH_TEST)
+ gl.glDepthFunc(gl.GL_LEQUAL)
+ gl.glDepthRange(0., 1.)
+
+ # gl.glEnable(gl.GL_POLYGON_OFFSET_FILL)
+ # gl.glPolygonOffset(1., 1.)
+
+ gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
+ gl.glEnable(gl.GL_LINE_SMOOTH)
+
+ gl.glClearColor(*self.background)
+
+ # Prepare OpenGL
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT |
+ gl.GL_STENCIL_BUFFER_BIT |
+ gl.GL_DEPTH_BUFFER_BIT)
+
+ ctx = RenderContext(self, glContext)
+ self.scene.render(ctx)
+ self.scene.postRender(ctx)
+
+ def adjustCameraDepthExtent(self):
+ """Update camera depth extent to fit the scene bounds.
+
+ Only near and far planes are updated.
+ The scene might still not be fully visible
+ (e.g., if spanning behind the viewpoint with perspective projection).
+ """
+ bounds = self.scene.bounds(transformed=True)
+ bounds = self.camera.extrinsic.transformBounds(bounds)
+
+ if isinstance(self.camera.intrinsic, transform.Perspective):
+ # This needs to be reworked
+ zbounds = - bounds[:, 2]
+ zextent = max(numpy.fabs(zbounds[0] - zbounds[1]), 0.0001)
+ near = max(zextent / 1000., 0.95 * zbounds[1])
+ far = max(near + 0.1, 1.05 * zbounds[0])
+
+ self.camera.intrinsic.setDepthExtent(near, far)
+ elif isinstance(self.camera.intrinsic, transform.Orthographic):
+ # Makes sure z bounds are included
+ border = max(abs(bounds[:, 2]))
+ self.camera.intrinsic.setDepthExtent(-border, border)
+ else:
+ raise RuntimeError('Unsupported camera', self.camera.intrinsic)
+
+ def resetCamera(self):
+ """Change camera to have the whole scene in the viewing frustum.
+
+ It updates the camera position and depth extent.
+ Camera sight direction and up are not affected.
+ """
+ self.camera.resetCamera(self.scene.bounds(transformed=True))
+
+ def orbitCamera(self, direction, angle=1.):
+ """Rotate the camera around center of the scene.
+
+ :param str direction: Direction of movement relative to image plane.
+ In: 'up', 'down', 'left', 'right'.
+ :param float angle: he angle in degrees of the rotation.
+ """
+ bounds = self.scene.bounds(transformed=True)
+ center = 0.5 * (bounds[0] + bounds[1])
+ self.camera.orbit(direction, center, angle)
+
+ def moveCamera(self, direction, step=0.1):
+ """Move the camera relative to the image plane.
+
+ :param str direction: Direction relative to image plane.
+ One of: 'up', 'down', 'left', 'right',
+ 'forward', 'backward'.
+ :param float step: The ratio of data to step for each pan.
+ """
+ bounds = self.scene.bounds(transformed=True)
+ bounds = self.camera.extrinsic.transformBounds(bounds)
+ center = 0.5 * (bounds[0] + bounds[1])
+ ndcCenter = self.camera.intrinsic.transformPoint(
+ center, perspectiveDivide=True)
+
+ step *= 2. # NDC has size 2
+
+ if direction == 'up':
+ ndcCenter[1] -= step
+ elif direction == 'down':
+ ndcCenter[1] += step
+
+ elif direction == 'right':
+ ndcCenter[0] -= step
+ elif direction == 'left':
+ ndcCenter[0] += step
+
+ elif direction == 'forward':
+ ndcCenter[2] += step
+ elif direction == 'backward':
+ ndcCenter[2] -= step
+
+ else:
+ raise ValueError('Unsupported direction: %s' % direction)
+
+ newCenter = self.camera.intrinsic.transformPoint(
+ ndcCenter, direct=False, perspectiveDivide=True)
+
+ self.camera.move(direction, numpy.linalg.norm(newCenter - center))
+
+ def windowToNdc(self, winX, winY, checkInside=True):
+ """Convert position from window to normalized device coordinates.
+
+ If window coordinates are int, they are moved half a pixel
+ to be positioned at the center of pixel.
+
+ :param winX: X window coord, origin left.
+ :param winY: Y window coord, origin top.
+ :param bool checkInside: If True, returns None if position is
+ outside viewport.
+ :return: (x, y) Normalize device coordinates in [-1, 1] or None.
+ Origin center, x to the right, y goes upward.
+ """
+ ox, oy = self._origin
+ width, height = self.size
+
+ # If int, move it to the center of pixel
+ if isinstance(winX, int):
+ winX += 0.5
+ if isinstance(winY, int):
+ winY += 0.5
+
+ x, y = winX - ox, winY - oy
+
+ if checkInside and (x < 0. or x > width or y < 0. or y > height):
+ return None # Out of viewport
+
+ ndcx = 2. * x / float(width) - 1.
+ ndcy = 1. - 2. * y / float(height)
+ return ndcx, ndcy
+
+ def ndcToWindow(self, ndcX, ndcY, checkInside=True):
+ """Convert position from normalized device coordinates (NDC) to window.
+
+ :param float ndcX: X NDC coord.
+ :param float ndcY: Y NDC coord.
+ :param bool checkInside: If True, returns None if position is
+ outside viewport.
+ :return: (x, y) window coordinates or None.
+ Origin top-left, x to the right, y goes downward.
+ """
+ if (checkInside and
+ (ndcX < -1. or ndcX > 1. or ndcY < -1. or ndcY > 1.)):
+ return None # Outside viewport
+
+ ox, oy = self._origin
+ width, height = self.size
+
+ winx = ox + width * 0.5 * (ndcX + 1.)
+ winy = oy + height * 0.5 * (1. - ndcY)
+ return winx, winy
+
+ def _pickNdcZGL(self, x, y):
+ """Retrieve depth from depth buffer and return corresponding NDC Z.
+
+ :param int x: In pixels in window coordinates, origin left.
+ :param int y: In pixels in window coordinates, origin top.
+ :return: Normalize device Z coordinate of depth in [-1, 1]
+ or None if outside viewport.
+ :rtype: float or None
+ """
+ ox, oy = self._origin
+ width, height = self.size
+
+ x = int(x)
+ y = height - int(y) # Invert y coord
+
+ if x < ox or x > ox + width or y < oy or y > oy + height:
+ # Outside viewport
+ return None
+
+ # Get depth from depth buffer in [0., 1.]
+ # Bind used framebuffer to get depth
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.framebuffer)
+ depth = gl.glReadPixels(
+ x, y, 1, 1, gl.GL_DEPTH_COMPONENT, gl.GL_FLOAT)[0]
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
+ # This is not GL|ES friendly
+
+ # Z in NDC in [-1., 1.]
+ return float(depth) * 2. - 1.
+
+ def _getXZYGL(self, x, y):
+ ndc = self.windowToNdc(x, y)
+ if ndc is None:
+ return None # Outside viewport
+ ndcz = self._pickNdcZGL(x, y)
+ ndcpos = numpy.array((ndc[0], ndc[1], ndcz, 1.), dtype=numpy.float32)
+
+ camerapos = self.camera.intrinsic.transformPoint(
+ ndcpos, direct=False, perspectiveDivide=True)
+
+ scenepos = self.camera.extrinsic.transformPoint(camerapos,
+ direct=False)
+ return scenepos[:3]
+
+ def pick(self, x, y):
+ pass
+ # ndcX, ndcY = self.windowToNdc(x, y)
+ # ndcNearPt = ndcX, ndcY, -1.
+ # ndcFarPT = ndcX, ndcY, 1.
diff --git a/silx/gui/plot3d/scene/window.py b/silx/gui/plot3d/scene/window.py
new file mode 100644
index 0000000..ad7e6e5
--- /dev/null
+++ b/silx/gui/plot3d/scene/window.py
@@ -0,0 +1,420 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a class for Viewports rendering on the screen.
+
+The :class:`Window` renders a list of Viewports in the current framebuffer.
+The rendering can be performed in an off-screen framebuffer that is only
+updated when the scene has changed and not each time Qt is requiring a repaint.
+
+The :class:`Context` and :class:`ContextGL2` represent the operating system
+OpenGL context and handle OpenGL resources.
+"""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "10/01/2017"
+
+
+import weakref
+import numpy
+
+from ..._glutils import gl
+from ... import _glutils
+
+from . import event
+
+
+class Context(object):
+ """Correspond to an operating system OpenGL context.
+
+ User should NEVER use an instance of this class beyond the method
+ it is passed to as an argument (i.e., do not keep a reference to it).
+
+ :param glContextHandle: System specific OpenGL context handle.
+ """
+
+ def __init__(self, glContextHandle):
+ self._context = glContextHandle
+ self._isCurrent = False
+ self._devicePixelRatio = 1.0
+
+ @property
+ def isCurrent(self):
+ """Whether this OpenGL context is the current one or not."""
+ return self._isCurrent
+
+ def setCurrent(self, isCurrent=True):
+ """Set the state of the OpenGL context to reflect OpenGL state.
+
+ This should not be called from the scene graph, only in the
+ wrapper that handle the OpenGL context to reflect its state.
+
+ :param bool isCurrent: The state of the system OpenGL context.
+ """
+ self._isCurrent = bool(isCurrent)
+
+ @property
+ def devicePixelRatio(self):
+ """Ratio between device and device independent pixels (float)
+
+ This is useful for font rendering.
+ """
+ return self._devicePixelRatio
+
+ @devicePixelRatio.setter
+ def devicePixelRatio(self, ratio):
+ assert ratio > 0
+ self._devicePixelRatio = float(ratio)
+
+ def __enter__(self):
+ self.setCurrent(True)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.setCurrent(False)
+
+ @property
+ def glContext(self):
+ """The handle to the OpenGL context provided by the system."""
+ return self._context
+
+ def cleanGLGarbage(self):
+ """This is releasing OpenGL resource that are no longer used."""
+ pass
+
+
+class ContextGL2(Context):
+ """Handle a system GL2 context.
+
+ User should NEVER use an instance of this class beyond the method
+ it is passed to as an argument (i.e., do not keep a reference to it).
+
+ :param glContextHandle: System specific OpenGL context handle.
+ """
+ def __init__(self, glContextHandle):
+ super(ContextGL2, self).__init__(glContextHandle)
+
+ self._programs = {} # GL programs already compiled
+ self._vbos = {} # GL Vbos already set
+ self._vboGarbage = [] # Vbos waiting to be discarded
+
+ # programs
+
+ def prog(self, vertexShaderSrc, fragmentShaderSrc):
+ """Cache program within context.
+
+ WARNING: No clean-up.
+ """
+ assert self.isCurrent
+ key = vertexShaderSrc, fragmentShaderSrc
+ prog = self._programs.get(key, None)
+ if prog is None:
+ prog = _glutils.Program(vertexShaderSrc, fragmentShaderSrc)
+ self._programs[key] = prog
+ return prog
+
+ # VBOs
+
+ def makeVbo(self, data=None, sizeInBytes=None,
+ usage=None, target=None):
+ """Create a VBO in this context with the data.
+
+ Current limitations:
+
+ - One array per VBO
+ - Do not support sharing VertexBuffer across VboAttrib
+
+ Automatically discards the VBO when the returned
+ :class:`VertexBuffer` istance is deleted.
+
+ :param numpy.ndarray data: 2D array of data to store in VBO or None.
+ :param int sizeInBytes: Size of the VBO or None.
+ It should be <= data.nbytes if both are given.
+ :param usage: OpenGL usage define in VertexBuffer._USAGES.
+ :param target: OpenGL target in VertexBuffer._TARGETS.
+ :return: The VertexBuffer created in this context.
+ """
+ assert self.isCurrent
+ vbo = _glutils.VertexBuffer(data, sizeInBytes, usage, target)
+ vboref = weakref.ref(vbo, self._deadVbo)
+ # weakref is hashable as far as target is
+ self._vbos[vboref] = vbo.name
+ return vbo
+
+ def makeVboAttrib(self, data, usage=None, target=None):
+ """Create a VBO from data and returns the associated VBOAttrib.
+
+ Automatically discards the VBO when the returned
+ :class:`VBOAttrib` istance is deleted.
+
+ :param numpy.ndarray data: 2D array of data to store in VBO or None.
+ :param usage: OpenGL usage define in VertexBuffer._USAGES.
+ :param target: OpenGL target in VertexBuffer._TARGETS.
+ :returns: A VBOAttrib instance created in this context.
+ """
+ assert self.isCurrent
+ vbo = self.makeVbo(data, usage=usage, target=target)
+
+ assert len(data.shape) <= 2
+ dimension = 1 if len(data.shape) == 1 else data.shape[1]
+
+ return _glutils.VertexBufferAttrib(
+ vbo,
+ type_=_glutils.numpyToGLType(data.dtype),
+ size=data.shape[0],
+ dimension=dimension,
+ offset=0,
+ stride=0)
+
+ def _deadVbo(self, vboRef):
+ """Callback handling dead VBOAttribs."""
+ vboid = self._vbos.pop(vboRef)
+ if self.isCurrent:
+ # Direct delete if context is active
+ gl.glDeleteBuffers(vboid)
+ else:
+ # Deferred VBO delete if context is not active
+ self._vboGarbage.append(vboid)
+
+ def cleanGLGarbage(self):
+ """Delete OpenGL resources that are pending for destruction.
+
+ This requires the associated OpenGL context to be active.
+ This is meant to be called before rendering.
+ """
+ assert self.isCurrent
+ if self._vboGarbage:
+ vboids = self._vboGarbage
+ gl.glDeleteBuffers(vboids)
+ self._vboGarbage = []
+
+
+class Window(event.Notifier):
+ """OpenGL Framebuffer where to render viewports
+
+ :param str mode: Rendering mode to use:
+
+ - 'direct' to render everything for each render call
+ - 'framebuffer' to cache viewport rendering in a texture and
+ update the texture only when needed.
+ """
+
+ _position = numpy.array(((-1., -1., 0., 0.),
+ (1., -1., 1., 0.),
+ (-1., 1., 0., 1.),
+ (1., 1., 1., 1.)),
+ dtype=numpy.float32)
+
+ _shaders = ("""
+ attribute vec4 position;
+ varying vec2 textureCoord;
+
+ void main(void) {
+ gl_Position = vec4(position.x, position.y, 0., 1.);
+ textureCoord = position.zw;
+ }
+ """,
+ """
+ uniform sampler2D texture;
+ varying vec2 textureCoord;
+
+ void main(void) {
+ gl_FragColor = texture2D(texture, textureCoord);
+ }
+ """)
+
+ def __init__(self, mode='framebuffer'):
+ super(Window, self).__init__()
+ self._dirty = True
+ self._size = 0, 0
+ self._contexts = {} # To map system GL context id to Context objects
+ self._viewports = event.NotifierList()
+ self._viewports.addListener(self._updated)
+ self._framebufferid = 0
+ self._framebuffers = {} # Cache of framebuffers
+
+ assert mode in ('direct', 'framebuffer')
+ self._isframebuffer = mode == 'framebuffer'
+
+ @property
+ def dirty(self):
+ """True if this object or any attached viewports is dirty."""
+ for viewport in self._viewports:
+ if viewport.dirty:
+ return True
+ return self._dirty
+
+ @property
+ def size(self):
+ """Size (width, height) of the window in pixels"""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ w, h = size
+ size = int(w), int(h)
+ if size != self._size:
+ self._size = size
+ self._dirty = True
+ self.notify()
+
+ @property
+ def shape(self):
+ """Shape (height, width) of the window in pixels.
+
+ This is a convenient wrapper to the reverse of size.
+ """
+ return self._size[1], self._size[0]
+
+ @shape.setter
+ def shape(self, shape):
+ self.size = shape[1], shape[0]
+
+ @property
+ def viewports(self):
+ """List of viewports to render in the corresponding framebuffer"""
+ return self._viewports
+
+ @viewports.setter
+ def viewports(self, iterable):
+ self._viewports.removeListener(self._updated)
+ self._viewports = event.NotifierList(iterable)
+ self._viewports.addListener(self._updated)
+ self._dirty = True
+
+ def _updated(self, source, *args, **kwargs):
+ if source is not self:
+ self._dirty = True
+ self.notify(*args, **kwargs)
+
+ framebufferid = property(lambda self: self._framebufferid,
+ doc="Framebuffer ID used to perform rendering")
+
+ def grab(self, glcontext):
+ """Returns the raster of the scene as an RGB numpy array
+
+ :returns: OpenGL scene RGB bitmap
+ :rtype: numpy.ndarray of uint8 of dimension (height, width, 3)
+ """
+ height, width = self.shape
+ image = numpy.empty((height, width, 3), dtype=numpy.uint8)
+
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.framebufferid)
+ gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
+ gl.glReadPixels(
+ 0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE, image)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
+
+ # glReadPixels gives bottom to top,
+ # while images are stored as top to bottom
+ image = numpy.flipud(image)
+
+ return numpy.array(image, copy=False, order='C')
+
+ def render(self, glcontext, devicePixelRatio):
+ """Perform the rendering of attached viewports
+
+ :param glcontext: System identifier of the OpenGL context
+ :param float devicePixelRatio:
+ Ratio between device and device-independent pixels
+ """
+ if glcontext not in self._contexts:
+ self._contexts[glcontext] = ContextGL2(glcontext) # New context
+
+ with self._contexts[glcontext] as context:
+ context.devicePixelRatio = devicePixelRatio
+ if self._isframebuffer:
+ self._renderWithOffscreenFramebuffer(context)
+ else:
+ self._renderDirect(context)
+
+ self._dirty = False
+
+ def _renderDirect(self, context):
+ """Perform the direct rendering of attached viewports
+
+ :param Context context: Object wrapping OpenGL context
+ """
+ for viewport in self._viewports:
+ viewport.framebuffer = self.framebufferid
+ viewport.render(context)
+ viewport.resetDirty()
+
+ def _renderWithOffscreenFramebuffer(self, context):
+ """Renders viewports in a texture and render this texture on screen.
+
+ The texture is updated only if viewport or size has changed.
+
+ :param ContextGL2 context: Object wrappign OpenGL context
+ """
+ if self.dirty or context not in self._framebuffers:
+ # Need to redraw framebuffer content
+
+ if (context not in self._framebuffers or
+ self._framebuffers[context].shape != self.shape):
+ # Need to rebuild framebuffer
+
+ if context in self._framebuffers:
+ self._framebuffers[context].discard()
+
+ fbo = _glutils.FramebufferTexture(gl.GL_RGBA,
+ shape=self.shape,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=gl.GL_CLAMP_TO_EDGE)
+ self._framebuffers[context] = fbo
+ self._framebufferid = fbo.name
+
+ # Render in framebuffer
+ with self._framebuffers[context]:
+ self._renderDirect(context)
+
+ # Render framebuffer texture to screen
+ fbo = self._framebuffers[context]
+ height, width = fbo.shape
+
+ program = context.prog(*self._shaders)
+ program.use()
+
+ gl.glViewport(0, 0, width, height)
+ gl.glDisable(gl.GL_BLEND)
+ gl.glDisable(gl.GL_DEPTH_TEST)
+ gl.glDisable(gl.GL_SCISSOR_TEST)
+ # gl.glScissor(0, 0, width, height)
+ gl.glClearColor(0., 0., 0., 0.)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+ gl.glUniform1i(program.uniforms['texture'], fbo.texture.texUnit)
+ gl.glEnableVertexAttribArray(program.attributes['position'])
+ gl.glVertexAttribPointer(program.attributes['position'],
+ 4,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._position)
+ fbo.texture.bind()
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._position))
+ gl.glBindTexture(gl.GL_TEXTURE_2D, 0)
diff --git a/silx/gui/plot3d/setup.py b/silx/gui/plot3d/setup.py
new file mode 100644
index 0000000..b9d626f
--- /dev/null
+++ b/silx/gui/plot3d/setup.py
@@ -0,0 +1,44 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('plot3d', parent_package, top_path)
+ config.add_subpackage('scene')
+ config.add_subpackage('test')
+ config.add_subpackage('utils')
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/silx/gui/plot3d/test/__init__.py b/silx/gui/plot3d/test/__init__.py
new file mode 100644
index 0000000..66a2f62
--- /dev/null
+++ b/silx/gui/plot3d/test/__init__.py
@@ -0,0 +1,62 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""plot3d test suite."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/01/2017"
+
+
+import logging
+import os
+import unittest
+
+
+_logger = logging.getLogger(__name__)
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+
+ if os.environ.get('WITH_GL_TEST', 'True') == 'False':
+ # Explicitly disabled tests
+ _logger.warning(
+ "silx.gui.plot3d tests disabled (WITH_GL_TEST=False)")
+
+ class SkipPlot3DTest(unittest.TestCase):
+ def runTest(self):
+ self.skipTest(
+ "silx.gui.plot3d tests disabled (WITH_GL_TEST=False)")
+
+ test_suite.addTest(SkipPlot3DTest())
+ return test_suite
+
+ # Import here to avoid loading modules if tests are disabled
+
+ from ..scene import test as test_scene
+
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(test_scene.suite())
+ return test_suite
diff --git a/silx/gui/plot3d/utils/__init__.py b/silx/gui/plot3d/utils/__init__.py
new file mode 100644
index 0000000..99d3e08
--- /dev/null
+++ b/silx/gui/plot3d/utils/__init__.py
@@ -0,0 +1,28 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/10/2016"
diff --git a/silx/gui/plot3d/utils/mng.py b/silx/gui/plot3d/utils/mng.py
new file mode 100644
index 0000000..fe79a52
--- /dev/null
+++ b/silx/gui/plot3d/utils/mng.py
@@ -0,0 +1,121 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides basic writing Mulitple-image Network Graphics files.
+
+It only supports RGB888 images of the same shape stored as
+MNG-VLC (very low complexity) format.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/12/2016"
+
+
+import logging
+import struct
+import zlib
+
+import numpy
+
+_logger = logging.getLogger(__name__)
+
+
+def _png_chunk(name, data):
+ """Return a PNG chunk
+
+ :param str name: Chunk type
+ :param byte data: Chunk payload
+ """
+ length = struct.pack('>I', len(data))
+ name = [char.encode('ascii') for char in name]
+ chunk = struct.pack('cccc', *name) + data
+ crc = struct.pack('>I', zlib.crc32(chunk) & 0xffffffff)
+ return length + chunk + crc
+
+
+def convert(images, nb_images=0, fps=25):
+ """Convert RGB images to MNG-VLC format.
+
+ See http://www.libpng.org/pub/mng/spec/
+ See http://www.libpng.org/pub/png/book/
+ See http://www.libpng.org/pub/png/spec/1.2/
+
+ :param images: iterator of RGB888 images
+ :type images: iterator of numpy.ndarray of dimension 3
+ :param int nb_images: The number of images indicated in the MNG header
+ :param int fps: The frame rate indicated in the MNG header
+ :return: An iterator of MNG chunks as bytes
+ """
+ first_image = True
+
+ for image in images:
+ if first_image:
+ first_image = False
+
+ height, width = image.shape[:2]
+
+ # MNG signature
+ yield b'\x8aMNG\r\n\x1a\n'
+
+ # MHDR chunk: File header
+ yield _png_chunk('MHDR', struct.pack(
+ ">IIIIIII",
+ width,
+ height,
+ fps, # ticks
+ nb_images + 1, # layer count
+ nb_images, # frame count
+ nb_images, # play time
+ 1)) # profile: MNG-VLC no alpha: only least significant bit 1
+
+ assert image.shape == (height, width, 3)
+ assert image.dtype == numpy.dtype('uint8')
+
+ # IHDR chunk: Image header
+ depth = 8 # 8 bit per channel
+ color_type = 2 # 'truecolor' = RGB
+ interlace = 0 # No
+ yield _png_chunk('IHDR', struct.pack(">IIBBBBB",
+ width,
+ height,
+ depth,
+ color_type,
+ 0, 0, interlace))
+
+ # Add filter 'None' before each scanline
+ prepared_data = b'\x00' + b'\x00'.join(
+ line.tostring() for line in image) # TODO optimize that
+ compressed_data = zlib.compress(prepared_data, 8)
+
+ # IDAT chunk: Payload
+ yield _png_chunk('IDAT', compressed_data)
+
+ # IEND chunk: Image footer
+ yield _png_chunk('IEND', b'')
+
+ # MEND chunk: footer
+ yield _png_chunk('MEND', b'')
diff --git a/silx/gui/qt/__init__.py b/silx/gui/qt/__init__.py
new file mode 100644
index 0000000..44daa94
--- /dev/null
+++ b/silx/gui/qt/__init__.py
@@ -0,0 +1,61 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Common wrapper over Python Qt bindings:
+
+- `PyQt5 <http://pyqt.sourceforge.net/Docs/PyQt5/>`_,
+- `PyQt4 <http://pyqt.sourceforge.net/Docs/PyQt4/>`_ or
+- `PySide <http://www.pyside.org>`_.
+
+If a Qt binding is already loaded, it will use it, otherwise the different
+Qt bindings are tried in this order: PyQt4, PySide, PyQt5.
+
+The name of the loaded Qt binding is stored in the BINDING variable.
+
+This module provides a flat namespace over Qt bindings by importing
+all symbols from **QtCore** and **QtGui** packages and if available
+from **QtOpenGL** and **QtSvg** packages.
+For **PyQt5**, it also imports all symbols from **QtWidgets** and
+**QtPrintSupport** packages.
+
+Example of using :mod:`silx.gui.qt` module:
+
+>>> from silx.gui import qt
+>>> app = qt.QApplication([])
+>>> widget = qt.QWidget()
+
+For an alternative solution providing a structured namespace,
+see `qtpy <https://pypi.python.org/pypi/QtPy/>`_ which
+provides the namespace of PyQt5 over PyQt4 and PySide.
+"""
+
+import sys
+from ._qt import * # noqa
+from ._utils import * # noqa
+
+
+if sys.platform == "darwin":
+ if BINDING in ["PySide", "PyQt4"]:
+ from . import _macosx
+ _macosx.patch_QUrl_toLocalFile()
diff --git a/silx/gui/qt/_macosx.py b/silx/gui/qt/_macosx.py
new file mode 100644
index 0000000..07f3143
--- /dev/null
+++ b/silx/gui/qt/_macosx.py
@@ -0,0 +1,68 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+Patches for Mac OS X
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "30/11/2016"
+
+
+def patch_QUrl_toLocalFile():
+ """Apply a monkey-patch on qt.QUrl to allow to reach filename when the URL
+ come from a MIME data from a file drop. Without, `QUrl.toLocalName` with
+ some version of Mac OS X returns a path which looks like
+ `/.file/id=180.112`.
+
+ Qt5 is or will be patch, but Qt4 and PySide are not.
+
+ This fix uses the file URL and use an subprocess with an
+ AppleScript. The script convert the URI into a posix path.
+ The interpreter (osascript) is available on default OS X installs.
+
+ See https://bugreports.qt.io/browse/QTBUG-40449
+ """
+ from ._qt import QUrl
+ import subprocess
+
+ def QUrl_toLocalFile(self):
+ path = QUrl._oldToLocalFile(self)
+ if not path.startswith("/.file/id="):
+ return path
+
+ url = self.toString()
+ script = 'get posix path of my posix file \"%s\" -- kthxbai' % url
+ try:
+ p = subprocess.Popen(["osascript", "-e", script], stdout=subprocess.PIPE)
+ out, _err = p.communicate()
+ if p.returncode == 0:
+ return out.strip()
+ except OSError:
+ pass
+ return path
+
+ QUrl._oldToLocalFile = QUrl.toLocalFile
+ QUrl.toLocalFile = QUrl_toLocalFile
diff --git a/silx/gui/qt/_pyside_dynamic.py b/silx/gui/qt/_pyside_dynamic.py
new file mode 100644
index 0000000..a9246b9
--- /dev/null
+++ b/silx/gui/qt/_pyside_dynamic.py
@@ -0,0 +1,158 @@
+# -*- coding: utf-8 -*-
+
+# Taken from: https://gist.github.com/cpbotha/1b42a20c8f3eb9bb7cb8
+
+# Copyright (c) 2011 Sebastian Wiesner <lunaryorn@gmail.com>
+# Modifications by Charl Botha <cpbotha@vxlabs.com>
+# * customWidgets support (registerCustomWidget() causes segfault in
+# pyside 1.1.2 on Ubuntu 12.04 x86_64)
+# * workingDirectory support in loadUi
+
+# found this here:
+# https://github.com/lunaryorn/snippets/blob/master/qt4/designer/pyside_dynamic.py
+
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+"""
+ How to load a user interface dynamically with PySide.
+
+ .. moduleauthor:: Sebastian Wiesner <lunaryorn@gmail.com>
+"""
+
+from __future__ import (print_function, division, unicode_literals,
+ absolute_import)
+
+import logging
+
+from PySide.QtCore import QMetaObject
+from PySide.QtUiTools import QUiLoader
+from PySide.QtGui import QMainWindow
+
+
+_logger = logging.getLogger(__name__)
+
+
+class UiLoader(QUiLoader):
+ """
+ Subclass :class:`~PySide.QtUiTools.QUiLoader` to create the user interface
+ in a base instance.
+
+ Unlike :class:`~PySide.QtUiTools.QUiLoader` itself this class does not
+ create a new instance of the top-level widget, but creates the user
+ interface in an existing instance of the top-level class.
+
+ This mimics the behaviour of :func:`PyQt4.uic.loadUi`.
+ """
+
+ def __init__(self, baseinstance, customWidgets=None):
+ """
+ Create a loader for the given ``baseinstance``.
+
+ The user interface is created in ``baseinstance``, which must be an
+ instance of the top-level class in the user interface to load, or a
+ subclass thereof.
+
+ ``customWidgets`` is a dictionary mapping from class name to class
+ object for widgets that you've promoted in the Qt Designer
+ interface. Usually, this should be done by calling
+ registerCustomWidget on the QUiLoader, but
+ with PySide 1.1.2 on Ubuntu 12.04 x86_64 this causes a segfault.
+
+ ``parent`` is the parent object of this loader.
+ """
+
+ QUiLoader.__init__(self, baseinstance)
+ self.baseinstance = baseinstance
+ self.customWidgets = customWidgets
+
+ def createWidget(self, class_name, parent=None, name=''):
+ """
+ Function that is called for each widget defined in ui file,
+ overridden here to populate baseinstance instead.
+ """
+
+ if parent is None and self.baseinstance:
+ # supposed to create the top-level widget, return the base instance
+ # instead
+ return self.baseinstance
+
+ else:
+ if class_name in self.availableWidgets():
+ # create a new widget for child widgets
+ widget = QUiLoader.createWidget(self, class_name, parent, name)
+
+ else:
+ # if not in the list of availableWidgets,
+ # must be a custom widget
+ # this will raise KeyError if the user has not supplied the
+ # relevant class_name in the dictionary, or TypeError, if
+ # customWidgets is None
+ try:
+ widget = self.customWidgets[class_name](parent)
+
+ except (TypeError, KeyError):
+ raise Exception('No custom widget ' + class_name +
+ ' found in customWidgets param of' +
+ 'UiLoader __init__.')
+
+ if self.baseinstance:
+ # set an attribute for the new child widget on the base
+ # instance, just like PyQt4.uic.loadUi does.
+ setattr(self.baseinstance, name, widget)
+
+ # this outputs the various widget names, e.g.
+ # sampleGraphicsView, dockWidget, samplesTableView etc.
+ # print(name)
+
+ return widget
+
+
+def loadUi(uifile, baseinstance=None, package=None, resource_suffix=None):
+ """
+ Dynamically load a user interface from the given ``uifile``.
+
+ ``uifile`` is a string containing a file name of the UI file to load.
+
+ If ``baseinstance`` is ``None``, the a new instance of the top-level widget
+ will be created. Otherwise, the user interface is created within the given
+ ``baseinstance``. In this case ``baseinstance`` must be an instance of the
+ top-level widget class in the UI file to load, or a subclass thereof. In
+ other words, if you've created a ``QMainWindow`` interface in the designer,
+ ``baseinstance`` must be a ``QMainWindow`` or a subclass thereof, too. You
+ cannot load a ``QMainWindow`` UI file with a plain
+ :class:`~PySide.QtGui.QWidget` as ``baseinstance``.
+
+ :method:`~PySide.QtCore.QMetaObject.connectSlotsByName()` is called on the
+ created user interface, so you can implemented your slots according to its
+ conventions in your widget class.
+
+ Return ``baseinstance``, if ``baseinstance`` is not ``None``. Otherwise
+ return the newly created instance of the user interface.
+ """
+ if package is not None:
+ _logger.warning(
+ "loadUi package parameter not implemented with PySide")
+ if resource_suffix is not None:
+ _logger.warning(
+ "loadUi resource_suffix parameter not implemented with PySide")
+
+ loader = UiLoader(baseinstance)
+ widget = loader.load(uifile)
+ QMetaObject.connectSlotsByName(widget)
+ return widget
diff --git a/silx/gui/qt/_pyside_missing.py b/silx/gui/qt/_pyside_missing.py
new file mode 100644
index 0000000..a7e2781
--- /dev/null
+++ b/silx/gui/qt/_pyside_missing.py
@@ -0,0 +1,274 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+Python implementation of classes which are not provided by default by PySide.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "17/01/2017"
+
+
+from PySide.QtGui import QAbstractProxyModel
+from PySide.QtCore import QModelIndex
+from PySide.QtCore import Qt
+from PySide.QtGui import QItemSelection
+from PySide.QtGui import QItemSelectionRange
+
+
+class QIdentityProxyModel(QAbstractProxyModel):
+ """Python translation of the source code of Qt c++ file"""
+
+ def __init__(self, parent=None):
+ super(QIdentityProxyModel, self).__init__(parent)
+ self.__ignoreNextLayoutAboutToBeChanged = False
+ self.__ignoreNextLayoutChanged = False
+ self.__persistentIndexes = []
+
+ def columnCount(self, parent):
+ parent = self.mapToSource(parent)
+ return self.sourceModel().columnCount(parent)
+
+ def dropMimeData(self, data, action, row, column, parent):
+ parent = self.mapToSource(parent)
+ return self.sourceModel().dropMimeData(data, action, row, column, parent)
+
+ def index(self, row, column, parent=QModelIndex()):
+ parent = self.mapToSource(parent)
+ i = self.sourceModel().index(row, column, parent)
+ return self.mapFromSource(i)
+
+ def insertColumns(self, column, count, parent=QModelIndex()):
+ parent = self.mapToSource(parent)
+ return self.sourceModel().insertColumns(column, count, parent)
+
+ def insertRows(self, row, count, parent=QModelIndex()):
+ parent = self.mapToSource(parent)
+ return self.sourceModel().insertRows(row, count, parent)
+
+ def mapFromSource(self, sourceIndex):
+ if self.sourceModel() is None or not sourceIndex.isValid():
+ return QModelIndex()
+ index = self.createIndex(sourceIndex.row(), sourceIndex.column(), sourceIndex.internalPointer())
+ return index
+
+ def mapSelectionFromSource(self, sourceSelection):
+ proxySelection = QItemSelection()
+ if self.sourceModel() is None:
+ return proxySelection
+
+ cursor = sourceSelection.constBegin()
+ end = sourceSelection.constEnd()
+ while cursor != end:
+ topLeft = self.mapFromSource(cursor.topLeft())
+ bottomRight = self.mapFromSource(cursor.bottomRight())
+ proxyRange = QItemSelectionRange(topLeft, bottomRight)
+ proxySelection.append(proxyRange)
+ cursor += 1
+ return proxySelection
+
+ def mapSelectionToSource(self, proxySelection):
+ sourceSelection = QItemSelection()
+ if self.sourceModel() is None:
+ return sourceSelection
+
+ cursor = proxySelection.constBegin()
+ end = proxySelection.constEnd()
+ while cursor != end:
+ topLeft = self.mapToSource(cursor.topLeft())
+ bottomRight = self.mapToSource(cursor.bottomRight())
+ sourceRange = QItemSelectionRange(topLeft, bottomRight)
+ sourceSelection.append(sourceRange)
+ cursor += 1
+ return sourceSelection
+
+ def mapToSource(self, proxyIndex):
+ if self.sourceModel() is None or not proxyIndex.isValid():
+ return QModelIndex()
+ return self.sourceModel().createIndex(proxyIndex.row(), proxyIndex.column(), proxyIndex.internalPointer())
+
+ def match(self, start, role, value, hits=1, flags=Qt.MatchFlags(Qt.MatchStartsWith | Qt.MatchWrap)):
+ if self.sourceModel() is None:
+ return []
+
+ start = self.mapToSource(start)
+ sourceList = self.sourceModel().match(start, role, value, hits, flags)
+ proxyList = []
+ for cursor in sourceList:
+ proxyList.append(self.mapFromSource(cursor))
+ return proxyList
+
+ def parent(self, child):
+ sourceIndex = self.mapToSource(child)
+ sourceParent = sourceIndex.parent()
+ index = self.mapFromSource(sourceParent)
+ return index
+
+ def removeColumns(self, column, count, parent=QModelIndex()):
+ parent = self.mapToSource(parent)
+ return self.sourceModel().removeColumns(column, count, parent)
+
+ def removeRows(self, row, count, parent=QModelIndex()):
+ parent = self.mapToSource(parent)
+ return self.sourceModel().removeRows(row, count, parent)
+
+ def rowCount(self, parent=QModelIndex()):
+ parent = self.mapToSource(parent)
+ return self.sourceModel().rowCount(parent)
+
+ def setSourceModel(self, newSourceModel):
+ """Bind and unbind the source model events"""
+ self.beginResetModel()
+
+ sourceModel = self.sourceModel()
+ if sourceModel is not None:
+ sourceModel.rowsAboutToBeInserted.disconnect(self.__rowsAboutToBeInserted)
+ sourceModel.rowsInserted.disconnect(self.__rowsInserted)
+ sourceModel.rowsAboutToBeRemoved.disconnect(self.__rowsAboutToBeRemoved)
+ sourceModel.rowsRemoved.disconnect(self.__rowsRemoved)
+ sourceModel.rowsAboutToBeMoved.disconnect(self.__rowsAboutToBeMoved)
+ sourceModel.rowsMoved.disconnect(self.__rowsMoved)
+ sourceModel.columnsAboutToBeInserted.disconnect(self.__columnsAboutToBeInserted)
+ sourceModel.columnsInserted.disconnect(self.__columnsInserted)
+ sourceModel.columnsAboutToBeRemoved.disconnect(self.__columnsAboutToBeRemoved)
+ sourceModel.columnsRemoved.disconnect(self.__columnsRemoved)
+ sourceModel.columnsAboutToBeMoved.disconnect(self.__columnsAboutToBeMoved)
+ sourceModel.columnsMoved.disconnect(self.__columnsMoved)
+ sourceModel.modelAboutToBeReset.disconnect(self.__modelAboutToBeReset)
+ sourceModel.modelReset.disconnect(self.__modelReset)
+ sourceModel.dataChanged.disconnect(self.__dataChanged)
+ sourceModel.headerDataChanged.disconnect(self.__headerDataChanged)
+ sourceModel.layoutAboutToBeChanged.disconnect(self.__layoutAboutToBeChanged)
+ sourceModel.layoutChanged.disconnect(self.__layoutChanged)
+
+ super(QIdentityProxyModel, self).setSourceModel(newSourceModel)
+
+ sourceModel = self.sourceModel()
+ if sourceModel is not None:
+ sourceModel.rowsAboutToBeInserted.connect(self.__rowsAboutToBeInserted)
+ sourceModel.rowsInserted.connect(self.__rowsInserted)
+ sourceModel.rowsAboutToBeRemoved.connect(self.__rowsAboutToBeRemoved)
+ sourceModel.rowsRemoved.connect(self.__rowsRemoved)
+ sourceModel.rowsAboutToBeMoved.connect(self.__rowsAboutToBeMoved)
+ sourceModel.rowsMoved.connect(self.__rowsMoved)
+ sourceModel.columnsAboutToBeInserted.connect(self.__columnsAboutToBeInserted)
+ sourceModel.columnsInserted.connect(self.__columnsInserted)
+ sourceModel.columnsAboutToBeRemoved.connect(self.__columnsAboutToBeRemoved)
+ sourceModel.columnsRemoved.connect(self.__columnsRemoved)
+ sourceModel.columnsAboutToBeMoved.connect(self.__columnsAboutToBeMoved)
+ sourceModel.columnsMoved.connect(self.__columnsMoved)
+ sourceModel.modelAboutToBeReset.connect(self.__modelAboutToBeReset)
+ sourceModel.modelReset.connect(self.__modelReset)
+ sourceModel.dataChanged.connect(self.__dataChanged)
+ sourceModel.headerDataChanged.connect(self.__headerDataChanged)
+ sourceModel.layoutAboutToBeChanged.connect(self.__layoutAboutToBeChanged)
+ sourceModel.layoutChanged.connect(self.__layoutChanged)
+
+ self.endResetModel()
+
+ def __columnsAboutToBeInserted(self, parent, start, end):
+ parent = self.mapFromSource(parent)
+ self.beginInsertColumns(parent, start, end)
+
+ def __columnsAboutToBeMoved(self, sourceParent, sourceStart, sourceEnd, destParent, dest):
+ sourceParent = self.mapFromSource(sourceParent)
+ destParent = self.mapFromSource(destParent)
+ self.beginMoveColumns(sourceParent, sourceStart, sourceEnd, destParent, dest)
+
+ def __columnsAboutToBeRemoved(self, parent, start, end):
+ parent = self.mapFromSource(parent)
+ self.beginRemoveColumns(parent, start, end)
+
+ def __columnsInserted(self, parent, start, end):
+ self.endInsertColumns()
+
+ def __columnsMoved(self, sourceParent, sourceStart, sourceEnd, destParent, dest):
+ self.endMoveColumns()
+
+ def __columnsRemoved(self, parent, start, end):
+ self.endRemoveColumns()
+
+ def __dataChanged(self, topLeft, bottomRight):
+ topLeft = self.mapFromSource(topLeft)
+ bottomRight = self.mapFromSource(bottomRight)
+ self.dataChanged(topLeft, bottomRight)
+
+ def __headerDataChanged(self, orientation, first, last):
+ self.headerDataChanged(orientation, first, last)
+
+ def __layoutAboutToBeChanged(self):
+ """Store persistent indexes"""
+ if self.__ignoreNextLayoutAboutToBeChanged:
+ return
+
+ for proxyPersistentIndex in self.persistentIndexList():
+ self.__proxyIndexes.append()
+ sourcePersistentIndex = self.mapToSource(proxyPersistentIndex)
+ mapping = proxyPersistentIndex, sourcePersistentIndex
+ self.__persistentIndexes.append(mapping)
+
+ self.layoutAboutToBeChanged()
+
+ def __layoutChanged(self):
+ """Restore persistent indexes"""
+ if self.__ignoreNextLayoutChanged:
+ return
+
+ for mapping in self.__persistentIndexes:
+ proxyIndex, sourcePersistentIndex = mapping
+ sourcePersistentIndex = self.mapFromSource(sourcePersistentIndex)
+ self.changePersistentIndex(proxyIndex, sourcePersistentIndex)
+
+ self.__persistentIndexes = []
+
+ self.layoutChanged()
+
+ def __modelAboutToBeReset(self):
+ self.beginResetModel()
+
+ def __modelReset(self):
+ self.endResetModel()
+
+ def __rowsAboutToBeInserted(self, parent, start, end):
+ parent = self.mapFromSource(parent)
+ self.beginInsertRows(parent, start, end)
+
+ def __rowsAboutToBeMoved(self, sourceParent, sourceStart, sourceEnd, destParent, dest):
+ sourceParent = self.mapFromSource(sourceParent)
+ destParent = self.mapFromSource(destParent)
+ self.beginMoveRows(sourceParent, sourceStart, sourceEnd, destParent, dest)
+
+ def __rowsAboutToBeRemoved(self, parent, start, end):
+ parent = self.mapFromSource(parent)
+ self.beginRemoveRows(parent, start, end)
+
+ def __rowsInserted(self, parent, start, end):
+ self.endInsertRows()
+
+ def __rowsMoved(self, sourceParent, sourceStart, sourceEnd, destParent, dest):
+ self.endMoveRows()
+
+ def __rowsRemoved(self, parent, start, end):
+ self.endRemoveRows()
diff --git a/silx/gui/qt/_qt.py b/silx/gui/qt/_qt.py
new file mode 100644
index 0000000..0962c21
--- /dev/null
+++ b/silx/gui/qt/_qt.py
@@ -0,0 +1,229 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Common wrapper over Python Qt bindings:
+
+- `PyQt5 <http://pyqt.sourceforge.net/Docs/PyQt5/>`_,
+- `PyQt4 <http://pyqt.sourceforge.net/Docs/PyQt4/>`_ or
+- `PySide <http://www.pyside.org>`_.
+
+If a Qt binding is already loaded, it will use it, otherwise the different
+Qt bindings are tried in this order: PyQt4, PySide, PyQt5.
+
+The name of the loaded Qt binding is stored in the BINDING variable.
+
+For an alternative solution providing a structured namespace,
+see `qtpy <https://pypi.python.org/pypi/QtPy/>`_ which
+provides the namespace of PyQt5 over PyQt4 and PySide.
+"""
+
+__authors__ = ["V.A. Sole - ESRF Data Analysis"]
+__license__ = "MIT"
+__date__ = "17/01/2017"
+
+
+import logging
+import sys
+import traceback
+
+
+_logger = logging.getLogger(__name__)
+
+
+BINDING = None
+"""The name of the Qt binding in use: 'PyQt5', 'PyQt4' or 'PySide'."""
+
+QtBinding = None # noqa
+"""The Qt binding module in use: PyQt5, PyQt4 or PySide."""
+
+HAS_SVG = False
+"""True if Qt provides support for Scalable Vector Graphics (QtSVG)."""
+
+HAS_OPENGL = False
+"""True if Qt provides support for OpenGL (QtOpenGL)."""
+
+# First check for an already loaded wrapper
+if 'PySide.QtCore' in sys.modules:
+ BINDING = 'PySide'
+
+elif 'PyQt5.QtCore' in sys.modules:
+ BINDING = 'PyQt5'
+
+elif 'PyQt4.QtCore' in sys.modules:
+ BINDING = 'PyQt4'
+
+else: # Then try Qt bindings
+ try:
+ import PyQt4 # noqa
+ except ImportError:
+ try:
+ import PySide # noqa
+ except ImportError:
+ try:
+ import PyQt5 # noqa
+ except ImportError:
+ raise ImportError(
+ 'No Qt wrapper found. Install PyQt4, PyQt5 or PySide.')
+ else:
+ BINDING = 'PyQt5'
+ else:
+ BINDING = 'PySide'
+ else:
+ BINDING = 'PyQt4'
+
+
+if BINDING == 'PyQt4':
+ _logger.debug('Using PyQt4 bindings')
+
+ if sys.version < "3.0.0":
+ try:
+ import sip
+
+ sip.setapi("QString", 2)
+ sip.setapi("QVariant", 2)
+ except:
+ _logger.warning("Cannot set sip API")
+
+ import PyQt4 as QtBinding # noqa
+
+ from PyQt4.QtCore import * # noqa
+ from PyQt4.QtGui import * # noqa
+
+ try:
+ from PyQt4.QtOpenGL import * # noqa
+ except ImportError:
+ _logger.info("PyQt4.QtOpenGL not available")
+ HAS_OPENGL = False
+ else:
+ HAS_OPENGL = True
+
+ try:
+ from PyQt4.QtSvg import * # noqa
+ except ImportError:
+ _logger.info("PyQt4.QtSvg not available")
+ HAS_SVG = False
+ else:
+ HAS_SVG = True
+
+ from PyQt4.uic import loadUi # noqa
+
+ Signal = pyqtSignal
+
+ Property = pyqtProperty
+
+ Slot = pyqtSlot
+
+elif BINDING == 'PySide':
+ _logger.debug('Using PySide bindings')
+
+ import PySide as QtBinding # noqa
+
+ from PySide.QtCore import * # noqa
+ from PySide.QtGui import * # noqa
+
+ try:
+ from PySide.QtOpenGL import * # noqa
+ except ImportError:
+ _logger.info("PySide.QtOpenGL not available")
+ HAS_OPENGL = False
+ else:
+ HAS_OPENGL = True
+
+ try:
+ from PySide.QtSvg import * # noqa
+ except ImportError:
+ _logger.info("PySide.QtSvg not available")
+ HAS_SVG = False
+ else:
+ HAS_SVG = True
+
+ pyqtSignal = Signal
+
+ # Import loadUi wrapper for PySide
+ from ._pyside_dynamic import loadUi # noqa
+
+ # Import missing classes
+ if not hasattr(locals(), "QIdentityProxyModel"):
+ from ._pyside_missing import QIdentityProxyModel # noqa
+
+elif BINDING == 'PyQt5':
+ _logger.debug('Using PyQt5 bindings')
+
+ import PyQt5 as QtBinding # noqa
+
+ from PyQt5.QtCore import * # noqa
+ from PyQt5.QtGui import * # noqa
+ from PyQt5.QtWidgets import * # noqa
+ from PyQt5.QtPrintSupport import * # noqa
+
+ try:
+ from PyQt5.QtOpenGL import * # noqa
+ except ImportError:
+ _logger.info("PySide.QtOpenGL not available")
+ HAS_OPENGL = False
+ else:
+ HAS_OPENGL = True
+
+ try:
+ from PyQt5.QtSvg import * # noqa
+ except ImportError:
+ _logger.info("PyQt5.QtSvg not available")
+ HAS_SVG = False
+ else:
+ HAS_SVG = True
+
+ from PyQt5.uic import loadUi # noqa
+
+ Signal = pyqtSignal
+
+ Property = pyqtProperty
+
+ Slot = pyqtSlot
+
+else:
+ raise ImportError('No Qt wrapper found. Install PyQt4, PyQt5 or PySide')
+
+# provide a exception handler but not implement it by default
+def exceptionHandler(type_, value, trace):
+ """
+ This exception handler prevents quitting to the command line when there is
+ an unhandled exception while processing a Qt signal.
+
+ The script/application willing to use it should implement code similar to:
+
+ .. code-block:: python
+
+ if __name__ == "__main__":
+ sys.excepthook = qt.exceptionHandler
+
+ """
+ _logger.error("%s %s %s", type_, value, ''.join(traceback.format_tb(trace)))
+ msg = QMessageBox()
+ msg.setWindowTitle("Unhandled exception")
+ msg.setIcon(QMessageBox.Critical)
+ msg.setInformativeText("%s %s\nPlease report details" % (type_, value))
+ msg.setDetailedText(("%s " % value) + ''.join(traceback.format_tb(trace)))
+ msg.raise_()
+ msg.exec_()
+
diff --git a/silx/gui/qt/_utils.py b/silx/gui/qt/_utils.py
new file mode 100644
index 0000000..0aa3ef1
--- /dev/null
+++ b/silx/gui/qt/_utils.py
@@ -0,0 +1,44 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides convenient functions related to Qt.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "30/11/2016"
+
+import sys
+from ._qt import BINDING, QImageReader
+
+
+def supportedImageFormats():
+ """Return a set of string of file format extensions supported by the
+ Qt runtime."""
+ if sys.version_info[0] < 3 or BINDING == 'PySide':
+ convert = str
+ else:
+ convert = lambda data: str(data, 'ascii')
+ formats = QImageReader.supportedImageFormats()
+ return set([convert(data) for data in formats])
diff --git a/silx/gui/setup.py b/silx/gui/setup.py
new file mode 100644
index 0000000..fbe9058
--- /dev/null
+++ b/silx/gui/setup.py
@@ -0,0 +1,51 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('gui', parent_package, top_path)
+ config.add_subpackage('_glutils')
+ config.add_subpackage('qt')
+ config.add_subpackage('plot')
+ config.add_subpackage('fit')
+ config.add_subpackage('hdf5')
+ config.add_subpackage('widgets')
+ config.add_subpackage('test')
+ config.add_subpackage('plot3d')
+ config.add_subpackage('data')
+
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/silx/gui/test/__init__.py b/silx/gui/test/__init__.py
new file mode 100644
index 0000000..7449860
--- /dev/null
+++ b/silx/gui/test/__init__.py
@@ -0,0 +1,108 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/01/2017"
+
+
+import logging
+import os
+import sys
+import unittest
+
+
+_logger = logging.getLogger(__name__)
+
+
+def suite():
+
+ test_suite = unittest.TestSuite()
+
+ if sys.platform.startswith('linux') and not os.environ.get('DISPLAY', ''):
+ # On Linux and no DISPLAY available (e.g., ssh without -X)
+ _logger.warning('silx.gui tests disabled (DISPLAY env. variable not set)')
+
+ class SkipGUITest(unittest.TestCase):
+ def runTest(self):
+ self.skipTest(
+ 'silx.gui tests disabled (DISPLAY env. variable not set)')
+
+ test_suite.addTest(SkipGUITest())
+ return test_suite
+
+ elif os.environ.get('WITH_QT_TEST', 'True') == 'False':
+ # Explicitly disabled tests
+ _logger.warning(
+ "silx.gui tests disabled (env. variable WITH_QT_TEST=False)")
+
+ class SkipGUITest(unittest.TestCase):
+ def runTest(self):
+ self.skipTest(
+ "silx.gui tests disabled (env. variable WITH_QT_TEST=False)")
+
+ test_suite.addTest(SkipGUITest())
+ return test_suite
+
+ # Import here to avoid loading QT if tests are disabled
+
+ from ..plot import test as test_plot
+ from ..fit import test as test_fit
+ from ..hdf5 import test as test_hdf5
+ from ..widgets import test as test_widgets
+ from ..data import test as test_data
+ from . import test_qt
+ # Console tests disabled due to corruption of python environment
+ # (see issue #538 on github)
+ # from . import test_console
+ from . import test_icons
+ from . import test_utils
+
+ try:
+ from ..plot3d.test import suite as test_plot3d_suite
+
+ except ImportError:
+ _logger.warning(
+ 'silx.gui.plot3d tests disabled '
+ '(PyOpenGL or QtOpenGL not installed)')
+
+ class SkipPlot3DTest(unittest.TestCase):
+ def runTest(self):
+ self.skipTest('silx.gui.plot3d tests disabled '
+ '(PyOpenGL or QtOpenGL not installed)')
+
+ test_plot3d_suite = SkipPlot3DTest
+
+
+ test_suite.addTest(test_qt.suite())
+ test_suite.addTest(test_plot.suite())
+ test_suite.addTest(test_fit.suite())
+ test_suite.addTest(test_hdf5.suite())
+ test_suite.addTest(test_widgets.suite())
+ # test_suite.addTest(test_console.suite()) # see issue #538 on github
+ test_suite.addTest(test_icons.suite())
+ test_suite.addTest(test_data.suite())
+ test_suite.addTest(test_utils.suite())
+ test_suite.addTest(test_plot3d_suite())
+ return test_suite
diff --git a/silx/gui/test/test_console.py b/silx/gui/test/test_console.py
new file mode 100644
index 0000000..7c25372
--- /dev/null
+++ b/silx/gui/test/test_console.py
@@ -0,0 +1,91 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for IPython console widget"""
+
+from __future__ import print_function
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import unittest
+
+from silx.gui.test.utils import TestCaseQt
+
+from silx.gui import qt
+try:
+ from silx.gui.console import IPythonDockWidget
+except ImportError:
+ console_missing = True
+else:
+ console_missing = False
+
+
+# dummy objects to test pushing variables to the interactive namespace
+_a = 1
+
+
+def _f():
+ print("Hello World!")
+
+
+@unittest.skipIf(console_missing, "Could not import Ipython and/or qtconsole")
+class TestConsole(TestCaseQt):
+ """Basic test for ``module.IPythonDockWidget``"""
+
+ def setUp(self):
+ super(TestConsole, self).setUp()
+ self.console = IPythonDockWidget(
+ available_vars={"a": _a, "f": _f},
+ custom_banner="Welcome!\n")
+ self.console.show()
+ self.qWaitForWindowExposed(self.console)
+
+ def tearDown(self):
+ self.console.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.console.close()
+ del self.console
+ super(TestConsole, self).tearDown()
+
+ def testShow(self):
+ pass
+
+ def testInteract(self):
+ self.mouseClick(self.console, qt.Qt.LeftButton)
+ self.keyClicks(self.console, 'import silx')
+ self.keyClick(self.console, qt.Qt.Key_Enter)
+ self.qapp.processEvents()
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestConsole))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/test/test_icons.py b/silx/gui/test/test_icons.py
new file mode 100644
index 0000000..f363c43
--- /dev/null
+++ b/silx/gui/test/test_icons.py
@@ -0,0 +1,116 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic test of Qt icons module."""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "26/04/2017"
+
+
+import gc
+import unittest
+import weakref
+
+from silx.gui import qt
+from silx.gui.test.utils import TestCaseQt
+from silx.gui import icons
+
+
+class TestIcons(TestCaseQt):
+ """Test to check that icons module."""
+
+ def testSvgIcon(self):
+ if "svg" not in qt.supportedImageFormats():
+ self.skipTest("SVG not supported")
+ icon = icons.getQIcon("test-svg")
+ self.assertIsNotNone(icon)
+
+ def testPngIcon(self):
+ icon = icons.getQIcon("test-png")
+ self.assertIsNotNone(icon)
+
+ def testUnexistingIcon(self):
+ self.assertRaises(ValueError, icons.getQIcon, "not-exists")
+
+ def testExistingQPixmap(self):
+ icon = icons.getQPixmap("crop")
+ self.assertIsNotNone(icon)
+
+ def testUnexistingQPixmap(self):
+ self.assertRaises(ValueError, icons.getQPixmap, "not-exists")
+
+ def testCache(self):
+ icon1 = icons.getQIcon("crop")
+ icon2 = icons.getQIcon("crop")
+ self.assertIs(icon1, icon2)
+
+ def testCacheReleased(self):
+ icon = icons.getQIcon("crop")
+ icon_ref = weakref.ref(icon)
+ del icon
+ gc.collect()
+ self.assertIsNone(icon_ref())
+
+
+class TestAnimatedIcons(TestCaseQt):
+ """Test to check that icons module."""
+
+ def testProcessWorking(self):
+ icon = icons.getWaitIcon()
+ self.assertIsNotNone(icon)
+
+ def testProcessWorkingCache(self):
+ icon1 = icons.getWaitIcon()
+ icon2 = icons.getWaitIcon()
+ self.assertIs(icon1, icon2)
+
+ def testMovieIconExists(self):
+ if "mng" not in qt.supportedImageFormats():
+ self.skipTest("MNG not supported")
+ icon = icons.MovieAnimatedIcon("process-working")
+ self.assertIsNotNone(icon)
+
+ def testMovieIconNotExists(self):
+ self.assertRaises(ValueError, icons.MovieAnimatedIcon, "not-exists")
+
+ def testMultiImageIconExists(self):
+ icon = icons.MultiImageAnimatedIcon("process-working")
+ self.assertIsNotNone(icon)
+
+ def testMultiImageIconNotExists(self):
+ self.assertRaises(ValueError, icons.MultiImageAnimatedIcon, "not-exists")
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestIcons))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestAnimatedIcons))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/test/test_qt.py b/silx/gui/test/test_qt.py
new file mode 100644
index 0000000..3a89a33
--- /dev/null
+++ b/silx/gui/test/test_qt.py
@@ -0,0 +1,144 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic test of Qt bindings wrapper."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import os.path
+import unittest
+
+from silx.test.utils import temp_dir
+from silx.gui.test.utils import TestCaseQt
+
+from silx.gui import qt
+
+
+class TestQtWrapper(unittest.TestCase):
+ """Minimalistic test to check that Qt has been loaded."""
+
+ def testQObject(self):
+ """Test that QObject is there."""
+ obj = qt.QObject()
+ self.assertTrue(obj is not None)
+
+
+class TestLoadUi(TestCaseQt):
+ """Test loadUi function"""
+
+ TEST_UI = """<?xml version="1.0" encoding="UTF-8"?>
+ <ui version="4.0">
+ <class>MainWindow</class>
+ <widget class="QMainWindow" name="MainWindow">
+ <property name="geometry">
+ <rect>
+ <x>0</x>
+ <y>0</y>
+ <width>293</width>
+ <height>296</height>
+ </rect>
+ </property>
+ <property name="windowTitle">
+ <string>Test loadUi</string>
+ </property>
+ <widget class="QWidget" name="centralwidget">
+ <widget class="QPushButton" name="pushButton">
+ <property name="geometry">
+ <rect>
+ <x>10</x>
+ <y>10</y>
+ <width>89</width>
+ <height>27</height>
+ </rect>
+ </property>
+ <property name="text">
+ <string>Button 1</string>
+ </property>
+ </widget>
+ <widget class="QPushButton" name="pushButton_2">
+ <property name="geometry">
+ <rect>
+ <x>10</x>
+ <y>50</y>
+ <width>89</width>
+ <height>27</height>
+ </rect>
+ </property>
+ <property name="text">
+ <string>Button 2</string>
+ </property>
+ </widget>
+ </widget>
+ <widget class="QMenuBar" name="menubar">
+ <property name="geometry">
+ <rect>
+ <x>0</x>
+ <y>0</y>
+ <width>293</width>
+ <height>25</height>
+ </rect>
+ </property>
+ </widget>
+ <widget class="QStatusBar" name="statusbar"/>
+ </widget>
+ <resources/>
+ <connections/>
+ </ui>
+ """
+
+ def testLoadUi(self):
+ """Create a QMainWindow from an ui file"""
+ with temp_dir() as tmp:
+ uifile = os.path.join(tmp, "test.ui")
+
+ # write file
+ with open(uifile, mode='w') as f:
+ f.write(self.TEST_UI)
+
+ class TestMainWindow(qt.QMainWindow):
+ def __init__(self, parent=None):
+ super(TestMainWindow, self).__init__(parent)
+ qt.loadUi(uifile, self)
+
+ testMainWindow = TestMainWindow()
+ testMainWindow.show()
+ self.qWaitForWindowExposed(testMainWindow)
+
+ testMainWindow.setAttribute(qt.Qt.WA_DeleteOnClose)
+ testMainWindow.close()
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ for TestCaseCls in (TestQtWrapper, TestLoadUi):
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestCaseCls))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/test/test_utils.py b/silx/gui/test/test_utils.py
new file mode 100644
index 0000000..4625969
--- /dev/null
+++ b/silx/gui/test/test_utils.py
@@ -0,0 +1,77 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of utils module."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+
+import unittest
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.test.utils import TestCaseQt
+
+from silx.gui import _utils
+
+
+class TestQImageConversion(TestCaseQt):
+ """Tests conversion of QImage to/from numpy array."""
+
+ def testConvertArrayToQImage(self):
+ """Test conversion of numpy array to QImage"""
+ image = numpy.ones((3, 3, 3), dtype=numpy.uint8)
+ qimage = _utils.convertArrayToQImage(image)
+
+ self.assertEqual(qimage.height(), image.shape[0])
+ self.assertEqual(qimage.width(), image.shape[1])
+ self.assertEqual(qimage.format(), qt.QImage.Format_RGB888)
+
+ color = qt.QColor(1, 1, 1).rgb()
+ self.assertEqual(qimage.pixel(1, 1), color)
+
+ def testConvertQImageToArray(self):
+ """Test conversion of QImage to numpy array"""
+ qimage = qt.QImage(3, 3, qt.QImage.Format_RGB888)
+ qimage.fill(0x010101)
+ image = _utils.convertQImageToArray(qimage)
+
+ self.assertEqual(qimage.height(), image.shape[0])
+ self.assertEqual(qimage.width(), image.shape[1])
+ self.assertEqual(image.shape[2], 3)
+ self.assertTrue(numpy.all(numpy.equal(image, 1)))
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
+ TestQImageConversion))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/test/utils.py b/silx/gui/test/utils.py
new file mode 100644
index 0000000..50cf7bf
--- /dev/null
+++ b/silx/gui/test/utils.py
@@ -0,0 +1,428 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Helper class to write Qt widget unittests."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "11/04/2017"
+
+
+import gc
+import logging
+import unittest
+import time
+import functools
+import sys
+
+logging.basicConfig()
+_logger = logging.getLogger(__name__)
+
+from silx.gui import qt
+
+if qt.BINDING == 'PySide':
+ from PySide.QtTest import QTest
+elif qt.BINDING == 'PyQt5':
+ from PyQt5.QtTest import QTest
+elif qt.BINDING == 'PyQt4':
+ from PyQt4.QtTest import QTest
+else:
+ raise ImportError('Unsupported Qt bindings')
+
+# Qt4/Qt5 compatibility wrapper
+if qt.BINDING in ('PySide', 'PyQt4'):
+ _logger.info("QTest.qWaitForWindowExposed not available," +
+ "using QTest.qWaitForWindowShown instead.")
+
+ def qWaitForWindowExposed(window, timeout=None):
+ """Mimic QTest.qWaitForWindowExposed for Qt4."""
+ QTest.qWaitForWindowShown(window)
+ return True
+else:
+ qWaitForWindowExposed = QTest.qWaitForWindowExposed
+
+
+def qWaitForWindowExposedAndActivate(window, timeout=None):
+ """Waits until the window is shown in the screen.
+
+ It also activates the window and raises it.
+
+ See QTest.qWaitForWindowExposed for details.
+ """
+ if timeout is None:
+ result = qWaitForWindowExposed(window)
+ else:
+ result = qWaitForWindowExposed(window, timeout)
+
+ if result:
+ # Makes sure window is active and on top
+ window.activateWindow()
+ window.raise_()
+
+ return result
+
+
+# Placeholder for QApplication
+_qapp = None
+
+
+class TestCaseQt(unittest.TestCase):
+ """Base class to write test for Qt stuff.
+
+ It creates a QApplication before running the tests.
+ WARNING: The QApplication is shared by all tests, which might have side
+ effects.
+
+ After each test, this class is checking for widgets remaining alive.
+ To allow some widgets to remain alive at the end of a test, set the
+ allowedLeakingWidgets attribute to the number of widgets that can remain
+ alive at the end of the test.
+ With PySide, this test is not run for now as it seems PySide
+ is leaking widgets internally.
+
+ All keyboard and mouse event simulation methods call qWait(20) after
+ simulating the event (as QTest does on Mac OSX).
+ This was introduced to fix issues with continuous integration tests
+ running with Xvfb on Linux.
+ """
+
+ DEFAULT_TIMEOUT_WAIT = 100
+ """Default timeout for qWait"""
+
+ TIMEOUT_WAIT = 0
+ """Extra timeout in millisecond to add to qSleep, qWait and
+ qWaitForWindowExposed.
+
+ Intended purpose is for debugging, to add extra time to waits in order to
+ allow to view the tested widgets.
+ """
+
+ @classmethod
+ def exceptionHandler(cls, exceptionClass, exception, stack):
+ import traceback
+ message = (''.join(traceback.format_tb(stack)))
+ template = 'Traceback (most recent call last):\n{2}{0}: {1}'
+ message = template.format(exceptionClass.__name__, exception, message)
+ cls._exceptions.append(message)
+
+ @classmethod
+ def setUpClass(cls):
+ """Makes sure Qt is inited"""
+ cls._oldExceptionHook = sys.excepthook
+ sys.excepthook = cls.exceptionHandler
+
+ global _qapp
+ if _qapp is None:
+ # Makes sure a QApplication exists and do it once for all
+ _qapp = qt.QApplication.instance() or qt.QApplication([])
+
+ # Create/delate a QWidget to make sure init of QDesktopWidget
+ _dummyWidget = qt.QWidget()
+ _dummyWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ _dummyWidget.show()
+ _dummyWidget.close()
+ _qapp.processEvents()
+
+ @classmethod
+ def tearDownClass(cls):
+ sys.excepthook = cls._oldExceptionHook
+
+ def setUp(self):
+ """Get the list of existing widgets."""
+ self.allowedLeakingWidgets = 0
+ self.__previousWidgets = self.qapp.allWidgets()
+ self.__class__._exceptions = []
+
+ def _currentTestSucceeded(self):
+ if hasattr(self, '_outcome'):
+ # For Python >= 3.4
+ result = self.defaultTestResult() # these 2 methods have no side effects
+ self._feedErrorsToResult(result, self._outcome.errors)
+ else:
+ # For Python < 3.4
+ result = getattr(self, '_outcomeForDoCleanups', self._resultForDoCleanups)
+
+ error = self.id() in [case.id() for case, _ in result.errors]
+ failure = self.id() in [case.id() for case, _ in result.failures]
+ return not error and not failure
+
+ def _checkForUnreleasedWidgets(self):
+ """Test fixture checking that no more widgets exists."""
+ gc.collect()
+
+ widgets = [widget for widget in self.qapp.allWidgets()
+ if widget not in self.__previousWidgets]
+ del self.__previousWidgets
+
+ if qt.BINDING == 'PySide':
+ return # Do not test for leaking widgets with PySide
+
+ allowedLeakingWidgets = self.allowedLeakingWidgets
+ self.allowedLeakingWidgets = 0
+
+ if widgets and len(widgets) <= allowedLeakingWidgets:
+ _logger.info(
+ '%s: %d remaining widgets after test' % (self.id(),
+ len(widgets)))
+
+ if len(widgets) > allowedLeakingWidgets:
+ raise RuntimeError(
+ "Test ended with widgets alive: %s" % str(widgets))
+
+ def tearDown(self):
+ if len(self.__class__._exceptions) > 0:
+ messages = "\n".join(self.__class__._exceptions)
+ raise AssertionError("Exception occured in Qt thread:\n" + messages)
+
+ if self._currentTestSucceeded():
+ self._checkForUnreleasedWidgets()
+
+ @property
+ def qapp(self):
+ """The QApplication currently running."""
+ return qt.QApplication.instance()
+
+ # Proxy to QTest
+
+ Press = QTest.Press
+ """Key press action code"""
+
+ Release = QTest.Release
+ """Key release action code"""
+
+ Click = QTest.Click
+ """Key click action code"""
+
+ QTest = property(lambda self: QTest,
+ doc="""The Qt QTest class from the used Qt binding.""")
+
+ def keyClick(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
+ """Simulate clicking a key.
+
+ See QTest.keyClick for details.
+ """
+ QTest.keyClick(widget, key, modifier, delay)
+ self.qWait(20)
+
+ def keyClicks(self, widget, sequence, modifier=qt.Qt.NoModifier, delay=-1):
+ """Simulate clicking a sequence of keys.
+
+ See QTest.keyClick for details.
+ """
+ QTest.keyClicks(widget, sequence, modifier, delay)
+ self.qWait(20)
+
+ def keyEvent(self, action, widget, key,
+ modifier=qt.Qt.NoModifier, delay=-1):
+ """Sends a Qt key event.
+
+ See QTest.keyEvent for details.
+ """
+ QTest.keyEvent(action, widget, key, modifier, delay)
+ self.qWait(20)
+
+ def keyPress(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
+ """Sends a Qt key press event.
+
+ See QTest.keyPress for details.
+ """
+ QTest.keyPress(widget, key, modifier, delay)
+ self.qWait(20)
+
+ def keyRelease(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
+ """Sends a Qt key release event.
+
+ See QTest.keyRelease for details.
+ """
+ QTest.keyRelease(widget, key, modifier, delay)
+ self.qWait(20)
+
+ def mouseClick(self, widget, button, modifier=None, pos=None, delay=-1):
+ """Simulate clicking a mouse button.
+
+ See QTest.mouseClick for details.
+ """
+ if modifier is None:
+ modifier = qt.Qt.KeyboardModifiers()
+ pos = qt.QPoint(pos[0], pos[1]) if pos is not None else qt.QPoint()
+ QTest.mouseClick(widget, button, modifier, pos, delay)
+ self.qWait(20)
+
+ def mouseDClick(self, widget, button, modifier=None, pos=None, delay=-1):
+ """Simulate double clicking a mouse button.
+
+ See QTest.mouseDClick for details.
+ """
+ if modifier is None:
+ modifier = qt.Qt.KeyboardModifiers()
+ pos = qt.QPoint(pos[0], pos[1]) if pos is not None else qt.QPoint()
+ QTest.mouseDClick(widget, button, modifier, pos, delay)
+ self.qWait(20)
+
+ def mouseMove(self, widget, pos=None, delay=-1):
+ """Simulate moving the mouse.
+
+ See QTest.mouseMove for details.
+ """
+ pos = qt.QPoint(pos[0], pos[1]) if pos is not None else qt.QPoint()
+ QTest.mouseMove(widget, pos, delay)
+ self.qWait(20)
+
+ def mousePress(self, widget, button, modifier=None, pos=None, delay=-1):
+ """Simulate pressing a mouse button.
+
+ See QTest.mousePress for details.
+ """
+ if modifier is None:
+ modifier = qt.Qt.KeyboardModifiers()
+ pos = qt.QPoint(pos[0], pos[1]) if pos is not None else qt.QPoint()
+ QTest.mousePress(widget, button, modifier, pos, delay)
+ self.qWait(20)
+
+ def mouseRelease(self, widget, button, modifier=None, pos=None, delay=-1):
+ """Simulate releasing a mouse button.
+
+ See QTest.mouseRelease for details.
+ """
+ if modifier is None:
+ modifier = qt.Qt.KeyboardModifiers()
+ pos = qt.QPoint(pos[0], pos[1]) if pos is not None else qt.QPoint()
+ QTest.mouseRelease(widget, button, modifier, pos, delay)
+ self.qWait(20)
+
+ def qSleep(self, ms):
+ """Sleep for ms milliseconds, blocking the execution of the test.
+
+ See QTest.qSleep for details.
+ """
+ QTest.qSleep(ms + self.TIMEOUT_WAIT)
+
+ def qWait(self, ms=None):
+ """Waits for ms milliseconds, events will be processed.
+
+ See QTest.qWait for details.
+ """
+ if ms is None:
+ ms = self.DEFAULT_TIMEOUT_WAIT
+
+ if qt.BINDING == 'PySide':
+ # PySide has no qWait, provide a replacement
+ timeout = int(ms)
+ endTimeMS = int(time.time() * 1000) + timeout
+ while timeout > 0:
+ self.qapp.processEvents(qt.QEventLoop.AllEvents,
+ maxtime=timeout)
+ timeout = endTimeMS - int(time.time() * 1000)
+ else:
+ QTest.qWait(ms + self.TIMEOUT_WAIT)
+
+ def qWaitForWindowExposed(self, window, timeout=None):
+ """Waits until the window is shown in the screen.
+
+ See QTest.qWaitForWindowExposed for details.
+ """
+ result = qWaitForWindowExposedAndActivate(window, timeout)
+
+ if self.TIMEOUT_WAIT:
+ QTest.qWait(self.TIMEOUT_WAIT)
+
+ return result
+
+
+class SignalListener():
+ """Util to listen a Qt event and store parameters
+ """
+
+ def __init__(self):
+ self.__calls = []
+
+ def __call__(self, *args, **kargs):
+ self.__calls.append((args, kargs))
+
+ def clear(self):
+ """Clear stored data"""
+ self.__calls = []
+
+ def callCount(self):
+ """
+ Returns how many times the listener was called.
+
+ :rtype: int
+ """
+ return len(self.__calls)
+
+ def arguments(self, callIndex=None, argumentIndex=None):
+ """Returns positional arguments optionally filtered by call count id
+ or argument index.
+
+ :param int callIndex: Index of the called data
+ :param int argumentIndex: Index of the positional argument.
+ """
+ if callIndex is not None:
+ result = self.__calls[callIndex][0]
+ if argumentIndex is not None:
+ result = result[argumentIndex]
+ else:
+ result = [x[0] for x in self.__calls]
+ if argumentIndex is not None:
+ result = [x[argumentIndex] for x in result]
+ return result
+
+ def karguments(self, callIndex=None, argumentName=None):
+ """Returns positional arguments optionally filtered by call count id
+ or name of the keyword argument.
+
+ :param int callIndex: Index of the called data
+ :param int argumentName: Name of the keyword argument.
+ """
+ if callIndex is not None:
+ result = self.__calls[callIndex][1]
+ if argumentName is not None:
+ result = result[argumentName]
+ else:
+ result = [x[1] for x in self.__calls]
+ if argumentName is not None:
+ result = [x[argumentName] for x in result]
+ return result
+
+ def partial(self, *args, **kargs):
+ """Returns a new partial object which when called will behave like this
+ listener called with the positional arguments args and keyword
+ arguments keywords. If more arguments are supplied to the call, they
+ are appended to args. If additional keyword arguments are supplied,
+ they extend and override keywords.
+ """
+ return functools.partial(self, *args, **kargs)
+
+
+def getQToolButtonFromAction(action):
+ """Return a QToolButton corresponding to a QAction.
+
+ :param QAction action: The QAction from which to get QToolButton.
+ :return: A QToolButton associated to action or None.
+ """
+ for widget in action.associatedWidgets():
+ if isinstance(widget, qt.QToolButton):
+ return widget
+ return None
diff --git a/silx/gui/widgets/FrameBrowser.py b/silx/gui/widgets/FrameBrowser.py
new file mode 100644
index 0000000..783a70a
--- /dev/null
+++ b/silx/gui/widgets/FrameBrowser.py
@@ -0,0 +1,307 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines two main classes:
+
+ - :class:`FrameBrowser`: a widget with 4 buttons (first, previous, next,
+ last) to browse between frames and a text entry to access a specific frame
+ by typing it's number)
+ - :class:`HorizontalSliderWithBrowser`: a FrameBrowser with an additional
+ slider. This class inherits :class:`qt.QAbstractSlider`.
+
+"""
+from silx.gui import qt
+from silx.gui import icons
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+
+class FrameBrowser(qt.QWidget):
+ """Frame browser widget, with 4 buttons/icons and a line edit to provide
+ a way of selecting a frame index in a stack of images.
+
+ It can be used in more generic case to select an integer within a range.
+
+ :param QWidget parent: Parent widget
+ :param int n: Number of frames. This will set the range
+ of frame indices to 0--n-1.
+ If None, the range is initialized to the default QSlider range (0--99)."""
+ sigIndexChanged = qt.pyqtSignal(object)
+
+ def __init__(self, parent=None, n=None):
+ qt.QWidget.__init__(self, parent)
+
+ # Use the font size as the icon size to avoid to create bigger buttons
+ fontMetric = self.fontMetrics()
+ iconSize = qt.QSize(fontMetric.height(), fontMetric.height())
+
+ self.mainLayout = qt.QHBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(0)
+ self.firstButton = qt.QPushButton(self)
+ self.firstButton.setIcon(icons.getQIcon("first"))
+ self.firstButton.setIconSize(iconSize)
+ self.previousButton = qt.QPushButton(self)
+ self.previousButton.setIcon(icons.getQIcon("previous"))
+ self.previousButton.setIconSize(iconSize)
+ self._lineEdit = qt.QLineEdit(self)
+
+ self._label = qt.QLabel(self)
+ self.nextButton = qt.QPushButton(self)
+ self.nextButton.setIcon(icons.getQIcon("next"))
+ self.nextButton.setIconSize(iconSize)
+ self.lastButton = qt.QPushButton(self)
+ self.lastButton.setIcon(icons.getQIcon("last"))
+ self.lastButton.setIconSize(iconSize)
+
+ self.mainLayout.addWidget(self.firstButton)
+ self.mainLayout.addWidget(self.previousButton)
+ self.mainLayout.addWidget(self._lineEdit)
+ self.mainLayout.addWidget(self._label)
+ self.mainLayout.addWidget(self.nextButton)
+ self.mainLayout.addWidget(self.lastButton)
+
+ if n is None:
+ first = qt.QSlider().minimum()
+ last = qt.QSlider().maximum()
+ else:
+ first, last = 0, n
+
+ self._lineEdit.setFixedWidth(self._lineEdit.fontMetrics().width('%05d' % last))
+ validator = qt.QIntValidator(first, last, self._lineEdit)
+ self._lineEdit.setValidator(validator)
+ self._lineEdit.setText("%d" % first)
+ self._label.setText("of %d" % last)
+
+ self._index = first
+ """0-based index"""
+
+ self.firstButton.clicked.connect(self._firstClicked)
+ self.previousButton.clicked.connect(self._previousClicked)
+ self.nextButton.clicked.connect(self._nextClicked)
+ self.lastButton.clicked.connect(self._lastClicked)
+ self._lineEdit.editingFinished.connect(self._textChangedSlot)
+
+ def lineEdit(self):
+ """Returns the line edit provided by this widget.
+
+ :rtype: qt.QLineEdit
+ """
+ return self._lineEdit
+
+ def limitWidget(self):
+ """Returns the widget displaying axes limits.
+
+ :rtype: qt.QLabel
+ """
+ return self._label
+
+ def _firstClicked(self):
+ """Select first/lowest frame number"""
+ self._lineEdit.setText("%d" % self._lineEdit.validator().bottom())
+ self._textChangedSlot()
+
+ def _previousClicked(self):
+ """Select previous frame number"""
+ if self._index > self._lineEdit.validator().bottom():
+ self._lineEdit.setText("%d" % (self._index - 1))
+ self._textChangedSlot()
+
+ def _nextClicked(self):
+ """Select next frame number"""
+ if self._index < (self._lineEdit.validator().top()):
+ self._lineEdit.setText("%d" % (self._index + 1))
+ self._textChangedSlot()
+
+ def _lastClicked(self):
+ """Select last/highest frame number"""
+ self._lineEdit.setText("%d" % self._lineEdit.validator().top())
+ self._textChangedSlot()
+
+ def _textChangedSlot(self):
+ """Select frame number typed in the line edit widget"""
+ txt = self._lineEdit.text()
+ if not len(txt):
+ self._lineEdit.setText("%d" % self._index)
+ return
+ new_value = int(txt)
+ if new_value == self._index:
+ return
+ ddict = {
+ "event": "indexChanged",
+ "old": self._index,
+ "new": new_value,
+ "id": id(self)
+ }
+ self._index = new_value
+ self.sigIndexChanged.emit(ddict)
+
+ def setRange(self, first, last):
+ """Set minimum and maximum frame indices
+ Initialize the frame index to *first*.
+ Update the label text to *" limits: first, last"*
+
+ :param int first: Minimum frame index
+ :param int last: Maximum frame index"""
+ return self.setLimits(first, last)
+
+ def setLimits(self, first, last):
+ """Set minimum and maximum frame indices.
+ Initialize the frame index to *first*.
+ Update the label text to *" limits: first, last"*
+
+ :param int first: Minimum frame index
+ :param int last: Maximum frame index"""
+ bottom = min(first, last)
+ top = max(first, last)
+ self._lineEdit.validator().setTop(top)
+ self._lineEdit.validator().setBottom(bottom)
+ self._index = bottom
+ self._lineEdit.setText("%d" % self._index)
+ self._label.setText(" limits: %d, %d " % (bottom, top))
+
+ def setNFrames(self, nframes):
+ """Set minimum=0 and maximum=nframes-1 frame numbers.
+ Initialize the frame index to 0.
+ Update the label text to *"1 of nframes"*
+
+ :param int nframes: Number of frames"""
+ bottom = 0
+ top = nframes - 1
+ self._lineEdit.validator().setTop(top)
+ self._lineEdit.validator().setBottom(bottom)
+ self._index = bottom
+ self._lineEdit.setText("%d" % self._index)
+ # display 1-based index in label
+ self._label.setText(" %d of %d " % (self._index + 1, top + 1))
+
+ def getCurrentIndex(self):
+ """Get 0-based frame index
+ """
+ return self._index
+
+ def setValue(self, value):
+ """Set 0-based frame index
+
+ :param int value: Frame number"""
+ self._lineEdit.setText("%d" % value)
+ self._textChangedSlot()
+
+
+class HorizontalSliderWithBrowser(qt.QAbstractSlider):
+ """
+ Slider widget combining a :class:`QSlider` and a :class:`FrameBrowser`.
+
+ The data model is an integer within a range.
+
+ The default value is the default :class:`QSlider` value (0),
+ and the default range is the default QSlider range (0 -- 99)
+
+ The signal emitted when the value is changed is the usual QAbstractSlider
+ signal :attr:`valueChanged`. The signal carries the value (as an integer).
+
+ :param QWidget parent: Optional parent widget
+ """
+ sigIndexChanged = qt.pyqtSignal(object)
+
+ def __init__(self, parent=None):
+ qt.QAbstractSlider.__init__(self, parent)
+ self.setOrientation(qt.Qt.Horizontal)
+
+ self.mainLayout = qt.QHBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(2)
+
+ self._slider = qt.QSlider(self)
+ self._slider.setOrientation(qt.Qt.Horizontal)
+
+ self._browser = FrameBrowser(self)
+
+ self.mainLayout.addWidget(self._slider, 1)
+ self.mainLayout.addWidget(self._browser)
+
+ self._slider.valueChanged[int].connect(self._sliderSlot)
+ self._browser.sigIndexChanged.connect(self._browserSlot)
+
+ def lineEdit(self):
+ """Returns the line edit provided by this widget.
+
+ :rtype: qt.QLineEdit
+ """
+ return self._browser.lineEdit()
+
+ def limitWidget(self):
+ """Returns the widget displaying axes limits.
+
+ :rtype: qt.QLabel
+ """
+ return self._browser.limitWidget()
+
+ def setMinimum(self, value):
+ """Set minimum value
+
+ :param int value: Minimum value"""
+ self._slider.setMinimum(value)
+ maximum = self._slider.maximum()
+ self._browser.setRange(value, maximum)
+
+ def setMaximum(self, value):
+ """Set maximum value
+
+ :param int value: Maximum value
+ """
+ self._slider.setMaximum(value)
+ minimum = self._slider.minimum()
+ self._browser.setRange(minimum, value)
+
+ def setRange(self, first, last):
+ """Set minimum/maximum values
+
+ :param int first: Minimum value
+ :param int last: Maximum value"""
+ self._slider.setRange(first, last)
+ self._browser.setRange(first, last)
+
+ def _sliderSlot(self, value):
+ """Emit selected value when slider is activated
+ """
+ self._browser.setValue(value)
+ self.valueChanged.emit(value)
+
+ def _browserSlot(self, ddict):
+ """Emit selected value when browser state is changed"""
+ self._slider.setValue(ddict['new'])
+
+ def setValue(self, value):
+ """Set value
+
+ :param int value: value"""
+ self._slider.setValue(value)
+ self._browser.setValue(value)
+
+ def value(self):
+ """Get selected value"""
+ return self._slider.value()
diff --git a/silx/gui/widgets/HierarchicalTableView.py b/silx/gui/widgets/HierarchicalTableView.py
new file mode 100644
index 0000000..3ccf4c7
--- /dev/null
+++ b/silx/gui/widgets/HierarchicalTableView.py
@@ -0,0 +1,172 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module define a hierarchical table view and model.
+
+It allows to define many headers in the middle of a table.
+
+The implementation hide the default header and allows to custom each cells
+to became a header.
+
+Row and column span is a concept of the view in a QTableView.
+This implementation also provide a span property as part of the model of the
+cell. A role is define to custom this information.
+The view is updated everytime the model is reset to take care of the
+changes of this information.
+
+A default item delegate is used to redefine the paint of the cells.
+"""
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "07/04/2017"
+
+from silx.gui import qt
+
+
+class HierarchicalTableModel(qt.QAbstractTableModel):
+ """
+ Abstract table model to provide more custom on row and column span and
+ headers.
+
+ Default headers are ignored and each cells can define IsHeaderRole and
+ SpanRole using the `data` function.
+ """
+
+ SpanRole = qt.Qt.UserRole + 0
+ """Role returning a tuple for number of row span then column span.
+
+ None and (1, 1) are neutral for the rendering.
+ """
+
+ IsHeaderRole = qt.Qt.UserRole + 1
+ """Role returning True is the identified cell is a header."""
+
+ UserRole = qt.Qt.UserRole + 2
+ """First index of user defined roles"""
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ """Returns the 0-based row or column index, for display in the
+ horizontal and vertical headers
+
+ In this case the headers are just ignored. Header information is part
+ of each cells.
+ """
+ return None
+
+
+class HierarchicalItemDelegate(qt.QStyledItemDelegate):
+ """
+ Delegate item to take care of the rendering of the default table cells and
+ also the header cells.
+ """
+
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QObject parent: Parent of the widget
+ """
+ qt.QStyledItemDelegate.__init__(self, parent)
+
+ def paint(self, painter, option, index):
+ """Override the paint function to inject the style of the header.
+
+ :param qt.QPainter painter: Painter context used to displayed the cell
+ :param qt.QStyleOptionViewItem option: Control how the editor is shown
+ :param qt.QIndex index: Index of the data to display
+ """
+ isHeader = index.data(role=HierarchicalTableModel.IsHeaderRole)
+ if isHeader:
+ span = index.data(role=HierarchicalTableModel.SpanRole)
+ span = 1 if span is None else span[1]
+ columnCount = index.model().columnCount()
+ if span == columnCount:
+ mainTitle = True
+ position = qt.QStyleOptionHeader.OnlyOneSection
+ else:
+ mainTitle = False
+ col = index.column()
+ if col == 0:
+ position = qt.QStyleOptionHeader.Beginning
+ elif col < columnCount - 1:
+ position = qt.QStyleOptionHeader.Middle
+ else:
+ position = qt.QStyleOptionHeader.End
+ opt = qt.QStyleOptionHeader()
+ opt.direction = option.direction
+ opt.text = index.data()
+ opt.textAlignment = qt.Qt.AlignCenter if mainTitle else qt.Qt.AlignVCenter
+ opt.direction = option.direction
+ opt.fontMetrics = option.fontMetrics
+ opt.palette = option.palette
+ opt.rect = option.rect
+ opt.state = option.state
+ opt.position = position
+ margin = -1
+ style = qt.QApplication.instance().style()
+ opt.rect = opt.rect.adjusted(margin, margin, -margin, -margin)
+ style.drawControl(qt.QStyle.CE_HeaderSection, opt, painter, None)
+ margin = 3
+ opt.rect = opt.rect.adjusted(margin, margin, -margin, -margin)
+ style.drawControl(qt.QStyle.CE_HeaderLabel, opt, painter, None)
+ else:
+ qt.QStyledItemDelegate.paint(self, painter, option, index)
+
+
+class HierarchicalTableView(qt.QTableView):
+ """A TableView which allow to display a `HierarchicalTableModel`."""
+
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QWidget parent: Parent of the widget
+ """
+ super(HierarchicalTableView, self).__init__(parent)
+ self.setItemDelegate(HierarchicalItemDelegate(self))
+ self.verticalHeader().setVisible(False)
+ self.horizontalHeader().setVisible(False)
+
+ def setModel(self, model):
+ """Override the default function to connect the model to update
+ function"""
+ if self.model() is not None:
+ model.modelReset.disconnect(self.__modelReset)
+ super(HierarchicalTableView, self).setModel(model)
+ if self.model() is not None:
+ model.modelReset.connect(self.__modelReset)
+ self.__modelReset()
+
+ def __modelReset(self):
+ """Update the model to take care of the changes of the span
+ information"""
+ self.clearSpans()
+ model = self.model()
+ for row in range(model.rowCount()):
+ for column in range(model.columnCount()):
+ index = model.index(row, column, qt.QModelIndex())
+ span = model.data(index, HierarchicalTableModel.SpanRole)
+ if span is not None and span != (1, 1):
+ self.setSpan(row, column, span[0], span[1])
diff --git a/silx/gui/widgets/MedianFilterDialog.py b/silx/gui/widgets/MedianFilterDialog.py
new file mode 100644
index 0000000..3eddff3
--- /dev/null
+++ b/silx/gui/widgets/MedianFilterDialog.py
@@ -0,0 +1,74 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+""" MedianFilterDialog
+Classes
+-------
+
+Widgets:
+
+ - :class:`MedianFilterDialog`
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "14/02/2017"
+
+from silx.gui import qt
+
+class MedianFilterDialog(qt.QDialog):
+ """QDialog window featuring a :class:`BackgroundWidget`"""
+ sigFilterOptChanged = qt.Signal(int, bool)
+
+ def __init__(self, parent=None):
+ qt.QDialog.__init__(self, parent)
+
+ self.setWindowTitle("Median filter options")
+ self.mainLayout = qt.QHBoxLayout(self)
+ self.setLayout(self.mainLayout)
+
+ # filter width GUI
+ self.mainLayout.addWidget(qt.QLabel('filter width:', parent = self))
+ self._filterWidth = qt.QSpinBox(parent=self)
+ self._filterWidth.setMinimum(1)
+ self._filterWidth.setValue(1)
+ self._filterWidth.setSingleStep(2);
+ widthTooltip = """radius width of the pixel including in the filter
+ for each pixel"""
+ self._filterWidth.setToolTip(widthTooltip)
+ self._filterWidth.valueChanged.connect(self._filterOptionChanged)
+ self.mainLayout.addWidget(self._filterWidth)
+
+ # filter option GUI
+ self._filterOption = qt.QCheckBox('conditional', parent=self)
+ conditionalTooltip = """if check, implement a conditional filter"""
+ self._filterOption.stateChanged.connect(self._filterOptionChanged)
+ self.mainLayout.addWidget(self._filterOption)
+
+ def _filterOptionChanged(self):
+ """Call back used when the filter values are changed"""
+ if self._filterWidth.value()%2 == 0:
+ logging.warning('median filter only accept odd values')
+ else:
+ self.sigFilterOptChanged.emit(self._filterWidth.value(), self._filterOption.isChecked()) \ No newline at end of file
diff --git a/silx/gui/widgets/PeriodicTable.py b/silx/gui/widgets/PeriodicTable.py
new file mode 100644
index 0000000..2f1ca78
--- /dev/null
+++ b/silx/gui/widgets/PeriodicTable.py
@@ -0,0 +1,825 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Periodic table widgets
+
+Classes
+-------
+
+Widgets:
+
+ - :class:`PeriodicTable`
+ - :class:`PeriodicList`
+ - :class:`PeriodicCombo`
+
+Data model:
+
+ - :class:`PeriodicTableItem`
+ - :class:`ColoredPeriodicTableItem`
+
+
+Example of usage
+----------------
+
+This example uses the widgets with the standard builtin elements list.
+
+.. code-block:: python
+
+ from silx.gui import qt
+ from silx.gui.widgets.PeriodicTable import PeriodicTable, \
+ PeriodicCombo, PeriodicList
+
+ a = qt.QApplication([])
+
+ w = qt.QTabWidget()
+
+ ptable = PeriodicTable(w, selectable=True)
+ pcombo = PeriodicCombo(w)
+ plist = PeriodicList(w)
+
+ w.addTab(ptable, "PeriodicTable")
+ w.addTab(plist, "PeriodicList")
+ w.addTab(pcombo, "PeriodicCombo")
+
+ ptable.setSelection(['H', 'Fe', 'Si'])
+ plist.setSelectedElements(['H', 'Be', 'F'])
+ pcombo.setSelection("Li")
+
+ def change_list(items):
+ print("New list selection:", [item.symbol for item in items])
+
+ def change_combo(item):
+ print("New combo selection:", item.symbol)
+
+ def click_table(item):
+ print("New table click:", item.symbol)
+
+ def change_table(items):
+ print("New table selection:", [item.symbol for item in items])
+
+ ptable.sigElementClicked.connect(click_table)
+ ptable.sigSelectionChanged.connect(change_table)
+ plist.sigSelectionChanged.connect(change_list)
+ pcombo.sigSelectionChanged.connect(change_combo)
+
+ w.show()
+ a.exec_()
+
+
+The second example explains how to define custom elements.
+
+.. code-block:: python
+
+ from silx.gui import qt
+ from silx.gui.widgets.PeriodicTable import PeriodicTable, \
+ PeriodicCombo, PeriodicList
+ from silx.gui.widgets.PeriodicTable import PeriodicTableItem
+
+ # subclass PeriodicTableItem
+ class MyPeriodicTableItem(PeriodicTableItem):
+ "New item with added mass number and number of protons"
+ def __init__(self, symbol, Z, A, col, row, name, mass,
+ subcategory=""):
+ PeriodicTableItem.__init__(
+ self, symbol, Z, col, row, name, mass,
+ subcategory)
+
+ self.A = A
+ "Mass number (neutrons + protons)"
+
+ self.num_neutrons = A - Z
+ "Number of neutrons"
+
+ # build your list of elements
+ my_elements = [MyPeriodicTableItem("H", 1, 1, 1, 1, "hydrogen",
+ 1.00800, "diatomic nonmetal"),
+ MyPeriodicTableItem("He", 2, 4, 18, 1, "helium",
+ 4.0030, "noble gas"),
+ # etc ...
+ ]
+
+ app = qt.QApplication([])
+
+ ptable = PeriodicTable(elements=my_elements, selectable=True)
+ ptable.show()
+
+ def click_table(item):
+ "Callback function printing the mass number of clicked element"
+ print("New table click, mass number:", item.A)
+
+ ptable.sigElementClicked.connect(click_table)
+ app.exec_()
+
+"""
+
+__authors__ = ["E. Papillon", "V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "26/01/2017"
+
+from collections import OrderedDict
+import logging
+from silx.gui import qt
+
+_logger = logging.getLogger(__name__)
+
+# Symbol Atomic Number col row name mass subcategory
+_elements = [("H", 1, 1, 1, "hydrogen", 1.00800, "diatomic nonmetal"),
+ ("He", 2, 18, 1, "helium", 4.0030, "noble gas"),
+ ("Li", 3, 1, 2, "lithium", 6.94000, "alkali metal"),
+ ("Be", 4, 2, 2, "beryllium", 9.01200, "alkaline earth metal"),
+ ("B", 5, 13, 2, "boron", 10.8110, "metalloid"),
+ ("C", 6, 14, 2, "carbon", 12.0100, "polyatomic nonmetal"),
+ ("N", 7, 15, 2, "nitrogen", 14.0080, "diatomic nonmetal"),
+ ("O", 8, 16, 2, "oxygen", 16.0000, "diatomic nonmetal"),
+ ("F", 9, 17, 2, "fluorine", 19.0000, "diatomic nonmetal"),
+ ("Ne", 10, 18, 2, "neon", 20.1830, "noble gas"),
+ ("Na", 11, 1, 3, "sodium", 22.9970, "alkali metal"),
+ ("Mg", 12, 2, 3, "magnesium", 24.3200, "alkaline earth metal"),
+ ("Al", 13, 13, 3, "aluminium", 26.9700, "post transition metal"),
+ ("Si", 14, 14, 3, "silicon", 28.0860, "metalloid"),
+ ("P", 15, 15, 3, "phosphorus", 30.9750, "polyatomic nonmetal"),
+ ("S", 16, 16, 3, "sulphur", 32.0660, "polyatomic nonmetal"),
+ ("Cl", 17, 17, 3, "chlorine", 35.4570, "diatomic nonmetal"),
+ ("Ar", 18, 18, 3, "argon", 39.9440, "noble gas"),
+ ("K", 19, 1, 4, "potassium", 39.1020, "alkali metal"),
+ ("Ca", 20, 2, 4, "calcium", 40.0800, "alkaline earth metal"),
+ ("Sc", 21, 3, 4, "scandium", 44.9600, "transition metal"),
+ ("Ti", 22, 4, 4, "titanium", 47.9000, "transition metal"),
+ ("V", 23, 5, 4, "vanadium", 50.9420, "transition metal"),
+ ("Cr", 24, 6, 4, "chromium", 51.9960, "transition metal"),
+ ("Mn", 25, 7, 4, "manganese", 54.9400, "transition metal"),
+ ("Fe", 26, 8, 4, "iron", 55.8500, "transition metal"),
+ ("Co", 27, 9, 4, "cobalt", 58.9330, "transition metal"),
+ ("Ni", 28, 10, 4, "nickel", 58.6900, "transition metal"),
+ ("Cu", 29, 11, 4, "copper", 63.5400, "transition metal"),
+ ("Zn", 30, 12, 4, "zinc", 65.3800, "transition metal"),
+ ("Ga", 31, 13, 4, "gallium", 69.7200, "post transition metal"),
+ ("Ge", 32, 14, 4, "germanium", 72.5900, "metalloid"),
+ ("As", 33, 15, 4, "arsenic", 74.9200, "metalloid"),
+ ("Se", 34, 16, 4, "selenium", 78.9600, "polyatomic nonmetal"),
+ ("Br", 35, 17, 4, "bromine", 79.9200, "diatomic nonmetal"),
+ ("Kr", 36, 18, 4, "krypton", 83.8000, "noble gas"),
+ ("Rb", 37, 1, 5, "rubidium", 85.4800, "alkali metal"),
+ ("Sr", 38, 2, 5, "strontium", 87.6200, "alkaline earth metal"),
+ ("Y", 39, 3, 5, "yttrium", 88.9050, "transition metal"),
+ ("Zr", 40, 4, 5, "zirconium", 91.2200, "transition metal"),
+ ("Nb", 41, 5, 5, "niobium", 92.9060, "transition metal"),
+ ("Mo", 42, 6, 5, "molybdenum", 95.9500, "transition metal"),
+ ("Tc", 43, 7, 5, "technetium", 99.0000, "transition metal"),
+ ("Ru", 44, 8, 5, "ruthenium", 101.0700, "transition metal"),
+ ("Rh", 45, 9, 5, "rhodium", 102.9100, "transition metal"),
+ ("Pd", 46, 10, 5, "palladium", 106.400, "transition metal"),
+ ("Ag", 47, 11, 5, "silver", 107.880, "transition metal"),
+ ("Cd", 48, 12, 5, "cadmium", 112.410, "transition metal"),
+ ("In", 49, 13, 5, "indium", 114.820, "post transition metal"),
+ ("Sn", 50, 14, 5, "tin", 118.690, "post transition metal"),
+ ("Sb", 51, 15, 5, "antimony", 121.760, "metalloid"),
+ ("Te", 52, 16, 5, "tellurium", 127.600, "metalloid"),
+ ("I", 53, 17, 5, "iodine", 126.910, "diatomic nonmetal"),
+ ("Xe", 54, 18, 5, "xenon", 131.300, "noble gas"),
+ ("Cs", 55, 1, 6, "caesium", 132.910, "alkali metal"),
+ ("Ba", 56, 2, 6, "barium", 137.360, "alkaline earth metal"),
+ ("La", 57, 3, 6, "lanthanum", 138.920, "lanthanide"),
+ ("Ce", 58, 4, 9, "cerium", 140.130, "lanthanide"),
+ ("Pr", 59, 5, 9, "praseodymium", 140.920, "lanthanide"),
+ ("Nd", 60, 6, 9, "neodymium", 144.270, "lanthanide"),
+ ("Pm", 61, 7, 9, "promethium", 147.000, "lanthanide"),
+ ("Sm", 62, 8, 9, "samarium", 150.350, "lanthanide"),
+ ("Eu", 63, 9, 9, "europium", 152.000, "lanthanide"),
+ ("Gd", 64, 10, 9, "gadolinium", 157.260, "lanthanide"),
+ ("Tb", 65, 11, 9, "terbium", 158.930, "lanthanide"),
+ ("Dy", 66, 12, 9, "dysprosium", 162.510, "lanthanide"),
+ ("Ho", 67, 13, 9, "holmium", 164.940, "lanthanide"),
+ ("Er", 68, 14, 9, "erbium", 167.270, "lanthanide"),
+ ("Tm", 69, 15, 9, "thulium", 168.940, "lanthanide"),
+ ("Yb", 70, 16, 9, "ytterbium", 173.040, "lanthanide"),
+ ("Lu", 71, 17, 9, "lutetium", 174.990, "lanthanide"),
+ ("Hf", 72, 4, 6, "hafnium", 178.500, "transition metal"),
+ ("Ta", 73, 5, 6, "tantalum", 180.950, "transition metal"),
+ ("W", 74, 6, 6, "tungsten", 183.920, "transition metal"),
+ ("Re", 75, 7, 6, "rhenium", 186.200, "transition metal"),
+ ("Os", 76, 8, 6, "osmium", 190.200, "transition metal"),
+ ("Ir", 77, 9, 6, "iridium", 192.200, "transition metal"),
+ ("Pt", 78, 10, 6, "platinum", 195.090, "transition metal"),
+ ("Au", 79, 11, 6, "gold", 197.200, "transition metal"),
+ ("Hg", 80, 12, 6, "mercury", 200.610, "transition metal"),
+ ("Tl", 81, 13, 6, "thallium", 204.390, "post transition metal"),
+ ("Pb", 82, 14, 6, "lead", 207.210, "post transition metal"),
+ ("Bi", 83, 15, 6, "bismuth", 209.000, "post transition metal"),
+ ("Po", 84, 16, 6, "polonium", 209.000, "post transition metal"),
+ ("At", 85, 17, 6, "astatine", 210.000, "metalloid"),
+ ("Rn", 86, 18, 6, "radon", 222.000, "noble gas"),
+ ("Fr", 87, 1, 7, "francium", 223.000, "alkali metal"),
+ ("Ra", 88, 2, 7, "radium", 226.000, "alkaline earth metal"),
+ ("Ac", 89, 3, 7, "actinium", 227.000, "actinide"),
+ ("Th", 90, 4, 10, "thorium", 232.000, "actinide"),
+ ("Pa", 91, 5, 10, "proactinium", 231.03588, "actinide"),
+ ("U", 92, 6, 10, "uranium", 238.070, "actinide"),
+ ("Np", 93, 7, 10, "neptunium", 237.000, "actinide"),
+ ("Pu", 94, 8, 10, "plutonium", 239.100, "actinide"),
+ ("Am", 95, 9, 10, "americium", 243, "actinide"),
+ ("Cm", 96, 10, 10, "curium", 247, "actinide"),
+ ("Bk", 97, 11, 10, "berkelium", 247, "actinide"),
+ ("Cf", 98, 12, 10, "californium", 251, "actinide"),
+ ("Es", 99, 13, 10, "einsteinium", 252, "actinide"),
+ ("Fm", 100, 14, 10, "fermium", 257, "actinide"),
+ ("Md", 101, 15, 10, "mendelevium", 258, "actinide"),
+ ("No", 102, 16, 10, "nobelium", 259, "actinide"),
+ ("Lr", 103, 17, 10, "lawrencium", 262, "actinide"),
+ ("Rf", 104, 4, 7, "rutherfordium", 261, "transition metal"),
+ ("Db", 105, 5, 7, "dubnium", 262, "transition metal"),
+ ("Sg", 106, 6, 7, "seaborgium", 266, "transition metal"),
+ ("Bh", 107, 7, 7, "bohrium", 264, "transition metal"),
+ ("Hs", 108, 8, 7, "hassium", 269, "transition metal"),
+ ("Mt", 109, 9, 7, "meitnerium", 268)]
+
+
+class PeriodicTableItem(object):
+ """Periodic table item, used as generic item in :class:`PeriodicTable`,
+ :class:`PeriodicCombo` and :class:`PeriodicList`.
+
+ This implementation stores the minimal amount of information needed by the
+ widgets:
+
+ - atomic symbol
+ - atomic number
+ - element name
+ - atomic mass
+ - column of element in periodic table
+ - row of element in periodic table
+
+ You can subclass this class to add additional information.
+
+ :param str symbol: Atomic symbol (e.g. H, He, Li...)
+ :param int Z: Proton number
+ :param int col: 1-based column index of element in periodic table
+ :param int row: 1-based row index of element in periodic table
+ :param str name: PeriodicTableItem name ("hydrogen", ...)
+ :param float mass: Atomic mass (gram per mol)
+ :param str subcategory: Subcategory, based on physical properties
+ (e.g. "alkali metal", "noble gas"...)
+ """
+ def __init__(self, symbol, Z, col, row, name, mass,
+ subcategory=""):
+ self.symbol = symbol
+ """Atomic symbol (e.g. H, He, Li...)"""
+ self.Z = Z
+ """Atomic number (Proton number)"""
+ self.col = col
+ """1-based column index of element in periodic table"""
+ self.row = row
+ """1-based row index of element in periodic table"""
+ self.name = name
+ """PeriodicTableItem name ("hydrogen", ...)"""
+ self.mass = mass
+ """Atomic mass (gram per mol)"""
+ self.subcategory = subcategory
+ """Subcategory, based on physical properties
+ (e.g. "alkali metal", "noble gas"...)"""
+
+ # pymca compatibility (elements used to be stored as a list of lists)
+ def __getitem__(self, idx):
+ if idx == 6:
+ _logger.warning("density not implemented in silx, returning 0.")
+
+ ret = [self.symbol, self.Z,
+ self.col, self.row,
+ self.name, self.mass,
+ 0.]
+ return ret[idx]
+
+ def __len__(self):
+ return 6
+
+
+class ColoredPeriodicTableItem(PeriodicTableItem):
+ """:class:`PeriodicTableItem` with an added :attr:`bgcolor`.
+ The background color can be passed as a parameter to the constructor.
+ If it is not specified, it will be defined based on
+ :attr:`subcategory`.
+
+ :param str bgcolor: Custom background color for element in
+ periodic table, as a RGB string *#RRGGBB*"""
+ COLORS = {
+ "diatomic nonmetal": "#7FFF00", # chartreuse
+ "noble gas": "#00FFFF", # cyan
+ "alkali metal": "#FFE4B5", # Moccasin
+ "alkaline earth metal": "#FFA500", # orange
+ "polyatomic nonmetal": "#7FFFD4", # aquamarine
+ "transition metal": "#FFA07A", # light salmon
+ "metalloid": "#8FBC8F", # Dark Sea Green
+ "post transition metal": "#D3D3D3", # light gray
+ "lanthanide": "#FFB6C1", # light pink
+ "actinide": "#F08080", # Light Coral
+ "": "#FFFFFF" # white
+ }
+ """Dictionary defining RGB colors for each subcategory."""
+
+ def __init__(self, symbol, Z, col, row, name, mass,
+ subcategory="", bgcolor=None):
+ PeriodicTableItem.__init__(self, symbol, Z, col, row, name, mass,
+ subcategory)
+
+ self.bgcolor = self.COLORS.get(subcategory, "#FFFFFF")
+ """Background color of element in the periodic table,
+ based on its subcategory. This should be a string of a hexadecimal
+ RGB code, with the format *#RRGGBB*.
+ If the subcategory is unknown, use white (*#FFFFFF*)
+ """
+
+ # possible custom color
+ if bgcolor is not None:
+ self.bgcolor = bgcolor
+
+
+_defaultTableItems = [ColoredPeriodicTableItem(*info) for info in _elements]
+
+
+class _ElementButton(qt.QPushButton):
+ """Atomic element button, used as a cell in the periodic table
+ """
+ sigElementEnter = qt.pyqtSignal(object)
+ """Signal emitted as the cursor enters the widget"""
+ sigElementLeave = qt.pyqtSignal(object)
+ """Signal emitted as the cursor leaves the widget"""
+ sigElementClicked = qt.pyqtSignal(object)
+ """Signal emitted when the widget is clicked"""
+
+ def __init__(self, item, parent=None):
+ """
+
+ :param parent: Parent widget
+ :param PeriodicTableItem item: :class:`PeriodicTableItem` object
+ """
+ qt.QPushButton.__init__(self, parent)
+
+ self.item = item
+ """:class:`PeriodicTableItem` object represented by this button"""
+
+ self.setText(item.symbol)
+ self.setFlat(1)
+ self.setCheckable(0)
+
+ self.setSizePolicy(qt.QSizePolicy(qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Expanding))
+
+ self.selected = False
+ self.current = False
+
+ # selection colors
+ self.selected_color = qt.QColor(qt.Qt.yellow)
+ self.current_color = qt.QColor(qt.Qt.gray)
+ self.selected_current_color = qt.QColor(qt.Qt.darkYellow)
+
+ # element colors
+
+ if hasattr(item, "bgcolor"):
+ self.bgcolor = qt.QColor(item.bgcolor)
+ else:
+ self.bgcolor = qt.QColor("#FFFFFF")
+
+ self.brush = qt.QBrush()
+ self.__setBrush()
+
+ self.clicked.connect(self.clickedSlot)
+
+ def sizeHint(self):
+ return qt.QSize(40, 40)
+
+ def setCurrent(self, b):
+ """Set this element button as current.
+ Multiple buttons can be selected.
+
+ :param b: boolean
+ """
+ self.current = b
+ self.__setBrush()
+
+ def isCurrent(self):
+ """
+ :return: True if element button is current
+ """
+ return self.current
+
+ def isSelected(self):
+ """
+ :return: True if element button is selected
+ """
+ return self.selected
+
+ def setSelected(self, b):
+ """Set this element button as selected.
+ Only a single button can be selected.
+
+ :param b: boolean
+ """
+ self.selected = b
+ self.__setBrush()
+
+ def __setBrush(self):
+ """Selected cells are yellow when not current.
+ The current cell is dark yellow when selected or grey when not
+ selected.
+ Other cells have no bg color by default, unless specified at
+ instantiation (:attr:`bgcolor`)"""
+ palette = self.palette()
+ # if self.current and self.selected:
+ # self.brush = qt.QBrush(self.selected_current_color)
+ # el
+ if self.selected:
+ self.brush = qt.QBrush(self.selected_color)
+ # elif self.current:
+ # self.brush = qt.QBrush(self.current_color)
+ elif self.bgcolor is not None:
+ self.brush = qt.QBrush(self.bgcolor)
+ else:
+ self.brush = qt.QBrush()
+ palette.setBrush(self.backgroundRole(),
+ self.brush)
+ self.setPalette(palette)
+ self.update()
+
+ def paintEvent(self, pEvent):
+ # get button geometry
+ widgGeom = self.rect()
+ paintGeom = qt.QRect(widgGeom.left() + 1,
+ widgGeom.top() + 1,
+ widgGeom.width() - 2,
+ widgGeom.height() - 2)
+
+ # paint background color
+ painter = qt.QPainter(self)
+ if self.brush is not None:
+ painter.fillRect(paintGeom, self.brush)
+ # paint frame
+ pen = qt.QPen(qt.Qt.black)
+ pen.setWidth(1 if not self.isCurrent() else 5)
+ painter.setPen(pen)
+ painter.drawRect(paintGeom)
+ painter.end()
+ qt.QPushButton.paintEvent(self, pEvent)
+
+ def enterEvent(self, e):
+ """Emit a :attr:`sigElementEnter` signal and send a
+ :class:`PeriodicTableItem` object"""
+ self.sigElementEnter.emit(self.item)
+
+ def leaveEvent(self, e):
+ """Emit a :attr:`sigElementLeave` signal and send a
+ :class:`PeriodicTableItem` object"""
+ self.sigElementLeave.emit(self.item)
+
+ def clickedSlot(self):
+ """Emit a :attr:`sigElementClicked` signal and send a
+ :class:`PeriodicTableItem` object"""
+ self.sigElementClicked.emit(self.item)
+
+
+class PeriodicTable(qt.QWidget):
+ """Periodic Table widget
+
+ The following example shows how to connect clicking to selection::
+
+ from silx.gui import qt
+ from silx.gui.widgets.PeriodicTable import PeriodicTable
+ app = qt.QApplication([])
+ pt = PeriodicTable()
+ pt.sigElementClicked.connect(pt.elementToggle)
+ pt.show()
+ app.exec_()
+
+ To print all selected elements each time a new element is selected::
+
+ def my_slot(item):
+ pt.elementToggle(item)
+ selected_elements = pt.getSelection()
+ for e in selected_elements:
+ print(e.symbol)
+
+ pt.sigElementClicked.connect(my_slot)
+
+ """
+ sigElementClicked = qt.pyqtSignal(object)
+ """When any element is clicked in the table, the widget emits
+ this signal and sends a :class:`PeriodicTableItem` object.
+ """
+
+ sigSelectionChanged = qt.pyqtSignal(object)
+ """When any element is selected/unselected in the table, the widget emits
+ this signal and sends a list of :class:`PeriodicTableItem` objects.
+
+ .. note::
+
+ To enable selection of elements, you must set *selectable=True*
+ when you instantiate the widget. Alternatively, you can also connect
+ :attr:`sigElementClicked` to :meth:`elementToggle` manually::
+
+ pt = PeriodicTable()
+ pt.sigElementClicked.connect(pt.elementToggle)
+
+
+ :param parent: parent QWidget
+ :param str name: Widget window title
+ :param elements: List of items (:class:`PeriodicTableItem` objects) to
+ be represented in the table. By default, take elements from
+ a predefined list with minimal information (symbol, atomic number,
+ name, mass).
+ :param bool selectable: If *True*, multiple elements can be
+ selected by clicking with the mouse. If *False* (default),
+ selection is only possible with method :meth:`setSelection`.
+ """
+
+ def __init__(self, parent=None, name="PeriodicTable", elements=None,
+ selectable=False):
+ self.selectable = selectable
+ qt.QWidget.__init__(self, parent)
+ self.setWindowTitle(name)
+ self.gridLayout = qt.QGridLayout(self)
+ self.gridLayout.setContentsMargins(0, 0, 0, 0)
+ self.gridLayout.addItem(qt.QSpacerItem(0, 5), 7, 0)
+
+ for idx in range(10):
+ self.gridLayout.setRowStretch(idx, 3)
+ # row 8 (above lanthanoids is empty)
+ self.gridLayout.setRowStretch(7, 2)
+
+ # Element information displayed when cursor enters a cell
+ self.eltLabel = qt.QLabel(self)
+ f = self.eltLabel.font()
+ f.setBold(1)
+ self.eltLabel.setFont(f)
+ self.eltLabel.setAlignment(qt.Qt.AlignHCenter)
+ self.gridLayout.addWidget(self.eltLabel, 1, 1, 3, 10)
+
+ self._eltCurrent = None
+ """Current :class:`_ElementButton` (last clicked)"""
+
+ self._eltButtons = OrderedDict()
+ """Dictionary of all :class:`_ElementButton`. Keys are the symbols
+ ("H", "He", "Li"...)"""
+
+ if elements is None:
+ elements = _defaultTableItems
+ # fill cells with elements
+ for elmt in elements:
+ self.__addElement(elmt)
+
+ def __addElement(self, elmt):
+ """Add one :class:`_ElementButton` widget into the grid,
+ connect its signals to interact with the cursor"""
+ b = _ElementButton(elmt, self)
+ b.setAutoDefault(False)
+
+ self._eltButtons[elmt.symbol] = b
+ self.gridLayout.addWidget(b, elmt.row, elmt.col)
+
+ b.sigElementEnter.connect(self.elementEnter)
+ b.sigElementLeave.connect(self._elementLeave)
+ b.sigElementClicked.connect(self._elementClicked)
+
+ def elementEnter(self, item):
+ """Update label with element info (e.g. "Nb(41) - niobium")
+ when mouse cursor hovers an element.
+
+ :param PeriodicTableItem item: Element entered by cursor
+ """
+ self.eltLabel.setText("%s(%d) - %s" % (item.symbol, item.Z, item.name))
+
+ def _elementLeave(self, item):
+ """Clear label when the cursor leaves the cell
+
+ :param PeriodicTableItem item: Element left
+ """
+ self.eltLabel.setText("")
+
+ def _elementClicked(self, item):
+ """Emit :attr:`sigElementClicked`,
+ toggle selected state of element
+
+ :param PeriodicTableItem item: Element clicked
+ """
+ if self._eltCurrent is not None:
+ self._eltCurrent.setCurrent(False)
+ self._eltButtons[item.symbol].setCurrent(True)
+ self._eltCurrent = self._eltButtons[item.symbol]
+ if self.selectable:
+ self.elementToggle(item)
+ self.sigElementClicked.emit(item)
+
+ def getSelection(self):
+ """Return a list of selected elements, as a list of :class:`PeriodicTableItem`
+ objects.
+
+ :return: Selected items
+ :rtype: list(PeriodicTableItem)
+ """
+ return [b.item for b in self._eltButtons.values() if b.isSelected()]
+
+ def setSelection(self, symbols):
+ """Set selected elements.
+
+ This causes the sigSelectionChanged signal
+ to be emitted, even if the selection didn't actually change.
+
+ :param list(str) symbols: List of symbols of elements to be selected
+ (e.g. *["Fe", "Hg", "Li"]*)
+ """
+ # accept list of PeriodicTableItems as input, because getSelection
+ # returns these objects and it makes sense to have getter and setter
+ # use same type of data
+ if isinstance(symbols[0], PeriodicTableItem):
+ symbols = [elmt.symbol for elmt in symbols]
+
+ for (e, b) in self._eltButtons.items():
+ b.setSelected(e in symbols)
+ self.sigSelectionChanged.emit(self.getSelection())
+
+ def setElementSelected(self, symbol, state):
+ """Modify *selected* status of a single element (select or unselect)
+
+ :param str symbol: PeriodicTableItem symbol to be selected
+ :param bool state: *True* to select, *False* to unselect
+ """
+ self._eltButtons[symbol].setSelected(state)
+ self.sigSelectionChanged.emit(self.getSelection())
+
+ def isElementSelected(self, symbol):
+ """Return *True* if element is selected, else *False*
+
+ :param str symbol: PeriodicTableItem symbol
+ :return: *True* if element is selected, else *False*
+ """
+ return self._eltButtons[symbol].isSelected()
+
+ def elementToggle(self, item):
+ """Toggle selected/unselected state for element
+
+ :param item: PeriodicTableItem object
+ """
+ b = self._eltButtons[item.symbol]
+ b.setSelected(not b.isSelected())
+ self.sigSelectionChanged.emit(self.getSelection())
+
+
+class PeriodicCombo(qt.QComboBox):
+ """
+ Combo list with all atomic elements of the periodic table
+
+ :param bool detailed: True (default) display element symbol, Z and name.
+ False display only element symbol and Z.
+ :param elements: List of items (:class:`PeriodicTableItem` objects) to
+ be represented in the table. By default, take elements from
+ a predefined list with minimal information (symbol, atomic number,
+ name, mass).
+ """
+ sigSelectionChanged = qt.pyqtSignal(object)
+ """Signal emitted when the selection changes. Send
+ :class:`PeriodicTableItem` object representing selected
+ element
+ """
+
+ def __init__(self, parent=None, detailed=True, elements=None):
+ qt.QComboBox.__init__(self, parent)
+
+ # add all elements from global list
+ if elements is None:
+ elements = _defaultTableItems
+ for i, elmt in enumerate(elements):
+ if detailed:
+ txt = "%2s (%d) - %s" % (elmt.symbol, elmt.Z, elmt.name)
+ else:
+ txt = "%2s (%d)" % (elmt.symbol, elmt.Z)
+ self.insertItem(i, txt)
+
+ self.currentIndexChanged[int].connect(self.__selectionChanged)
+
+ def __selectionChanged(self, idx):
+ """Emit :attr:`sigSelectionChanged`"""
+ self.sigSelectionChanged.emit(_defaultTableItems[idx])
+
+ def getSelection(self):
+ """Get selected element
+
+ :return: Selected element
+ :rtype: PeriodicTableItem
+ """
+ return _defaultTableItems[self.currentIndex()]
+
+ def setSelection(self, symbol):
+ """Set selected item in combobox by giving the atomic symbol
+
+ :param symbol: Symbol of element to be selected
+ """
+ # accept PeriodicTableItem for getter/setter consistency
+ if isinstance(symbol, PeriodicTableItem):
+ symbol = symbol.symbol
+ symblist = [elmt.symbol for elmt in _defaultTableItems]
+ self.setCurrentIndex(symblist.index(symbol))
+
+
+class PeriodicList(qt.QTreeWidget):
+ """List of atomic elements in a :class:`QTreeView`
+
+ :param QWidget parent: Parent widget
+ :param bool detailed: True (default) display element symbol, Z and name.
+ False display only element symbol and Z.
+ :param single: *True* for single element selection with mouse click,
+ *False* for multiple element selection mode.
+ """
+ sigSelectionChanged = qt.pyqtSignal(object)
+ """When any element is selected/unselected in the widget, it emits
+ this signal and sends a list of currently selected
+ :class:`PeriodicTableItem` objects.
+ """
+
+ def __init__(self, parent=None, detailed=True, single=False, elements=None):
+ qt.QTreeWidget.__init__(self, parent)
+
+ self.detailed = detailed
+
+ headers = ["Z", "Symbol"]
+ if detailed:
+ headers.append("Name")
+ self.setColumnCount(3)
+ else:
+ self.setColumnCount(2)
+ self.setHeaderLabels(headers)
+ self.header().setStretchLastSection(False)
+
+ self.setRootIsDecorated(0)
+ self.itemClicked.connect(self.__selectionChanged)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection if single
+ else qt.QAbstractItemView.ExtendedSelection)
+ self.__fill_widget(elements)
+ self.resizeColumnToContents(0)
+ self.resizeColumnToContents(1)
+ if detailed:
+ self.resizeColumnToContents(2)
+
+ def __fill_widget(self, elements):
+ """Fill tree widget with elements """
+ if elements is None:
+ elements = _defaultTableItems
+
+ self.tree_items = []
+
+ previous_item = None
+ for elmt in elements:
+ if previous_item is None:
+ item = qt.QTreeWidgetItem(self)
+ else:
+ item = qt.QTreeWidgetItem(self, previous_item)
+ item.setText(0, str(elmt.Z))
+ item.setText(1, elmt.symbol)
+ if self.detailed:
+ item.setText(2, elmt.name)
+ self.tree_items.append(item)
+ previous_item = item
+
+ def __selectionChanged(self, treeItem, column):
+ """Emit a :attr:`sigSelectionChanged` and send a list of
+ :class:`PeriodicTableItem` objects."""
+ self.sigSelectionChanged.emit(self.getSelection())
+
+ def getSelection(self):
+ """Get a list of selected elements, as a list of :class:`PeriodicTableItem`
+ objects.
+
+ :return: Selected elements
+ :rtype: list(PeriodicTableItem)"""
+ return [_defaultTableItems[idx] for idx in range(len(self.tree_items))
+ if self.tree_items[idx].isSelected()]
+
+ # setSelection is a bad name (name of a QTreeWidget method)
+ def setSelectedElements(self, symbolList):
+ """
+
+ :param symbolList: List of atomic symbols ["H", "He", "Li"...]
+ to be selected in the widget
+ """
+ # accept PeriodicTableItem for getter/setter consistency
+ if isinstance(symbolList[0], PeriodicTableItem):
+ symbolList = [elmt.symbol for elmt in symbolList]
+ for idx in range(len(self.tree_items)):
+ self.tree_items[idx].setSelected(_defaultTableItems[idx].symbol in symbolList)
diff --git a/silx/gui/widgets/TableWidget.py b/silx/gui/widgets/TableWidget.py
new file mode 100644
index 0000000..fad80ee
--- /dev/null
+++ b/silx/gui/widgets/TableWidget.py
@@ -0,0 +1,488 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides table widgets handling cut, copy and paste for
+multiple cell selections. These actions can be triggered using keyboard
+shortcuts or through a context menu (right-click).
+
+:class:`TableView` is a subclass of :class:`QTableView`. The added features
+are made available to users after a model is added to the widget, using
+:meth:`TableView.setModel`.
+
+:class:`TableWidget` is a subclass of :class:`qt.QTableWidget`, a table view
+with a built-in standard data model. The added features are available as soon as
+the widget is initialized.
+
+The cut, copy and paste actions are implemented as QActions:
+
+ - :class:`CopySelectedCellsAction` (*Ctrl+C*)
+ - :class:`CopyAllCellsAction`
+ - :class:`CutSelectedCellsAction` (*Ctrl+X*)
+ - :class:`CutAllCellsAction`
+ - :class:`PasteCellsAction` (*Ctrl+V*)
+
+The copy actions are enabled by default. The cut and paste actions must be
+explicitly enabled, by passing parameters ``cut=True, paste=True`` when
+creating the widgets, or later by calling their :meth:`enableCut` and
+:meth:`enablePaste` methods.
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "26/01/2017"
+
+
+import sys
+from .. import qt
+
+
+if sys.platform.startswith("win"):
+ row_separator = "\r\n"
+else:
+ row_separator = "\n"
+
+col_separator = "\t"
+
+
+class CopySelectedCellsAction(qt.QAction):
+ """QAction to copy text from selected cells in a :class:`QTableWidget`
+ into the clipboard.
+
+ If multiple cells are selected, the copied text will be a concatenation
+ of the texts in all selected cells, tabulated with tabulation and
+ newline characters.
+
+ If the cells are sparsely selected, the structure is preserved by
+ representing the unselected cells as empty strings in between two
+ tabulation characters.
+ Beware of pasting this data in another table widget, because depending
+ on how the paste is implemented, the empty cells may cause data in the
+ target table to be deleted, even though you didn't necessarily select the
+ corresponding cell in the origin table.
+
+ :param table: :class:`QTableView` to which this action belongs.
+ """
+ def __init__(self, table):
+ if not isinstance(table, qt.QTableView):
+ raise ValueError('CopySelectedCellsAction must be initialised ' +
+ 'with a QTableWidget.')
+ super(CopySelectedCellsAction, self).__init__(table)
+ self.setText("Copy selection")
+ self.setToolTip("Copy selected cells into the clipboard.")
+ self.setShortcut(qt.QKeySequence.Copy)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.triggered.connect(self.copyCellsToClipboard)
+ self.table = table
+ self.cut = False
+ """:attr:`cut` can be set to True by classes inheriting this action,
+ to do a cut action."""
+
+ def copyCellsToClipboard(self):
+ """Concatenate the text content of all selected cells into a string
+ using tabulations and newlines to keep the table structure.
+ Put this text into the clipboard.
+ """
+ selected_idx = self.table.selectedIndexes()
+ selected_idx_tuples = [(idx.row(), idx.column()) for idx in selected_idx]
+
+ selected_rows = [idx[0] for idx in selected_idx_tuples]
+ selected_columns = [idx[1] for idx in selected_idx_tuples]
+
+ data_model = self.table.model()
+
+ copied_text = ""
+ for row in range(min(selected_rows), max(selected_rows) + 1):
+ for col in range(min(selected_columns), max(selected_columns) + 1):
+ index = data_model.index(row, col)
+ cell_text = data_model.data(index)
+ flags = data_model.flags(index)
+
+ if (row, col) in selected_idx_tuples and cell_text is not None:
+ copied_text += cell_text
+ if self.cut and (flags & qt.Qt.ItemIsEditable):
+ data_model.setData(index, "")
+ copied_text += col_separator
+ # remove the right-most tabulation
+ copied_text = copied_text[:-len(col_separator)]
+ # add a newline
+ copied_text += row_separator
+ # remove final newline
+ copied_text = copied_text[:-len(row_separator)]
+
+ # put this text into clipboard
+ qapp = qt.QApplication.instance()
+ qapp.clipboard().setText(copied_text)
+
+
+class CopyAllCellsAction(qt.QAction):
+ """QAction to copy text from all cells in a :class:`QTableWidget`
+ into the clipboard.
+
+ The copied text will be a concatenation
+ of the texts in all cells, tabulated with tabulation and
+ newline characters.
+
+ :param table: :class:`QTableView` to which this action belongs.
+ """
+ def __init__(self, table):
+ if not isinstance(table, qt.QTableView):
+ raise ValueError('CopyAllCellsAction must be initialised ' +
+ 'with a QTableWidget.')
+ super(CopyAllCellsAction, self).__init__(table)
+ self.setText("Copy all")
+ self.setToolTip("Copy all cells into the clipboard.")
+ self.triggered.connect(self.copyCellsToClipboard)
+ self.table = table
+ self.cut = False
+
+ def copyCellsToClipboard(self):
+ """Concatenate the text content of all cells into a string
+ using tabulations and newlines to keep the table structure.
+ Put this text into the clipboard.
+ """
+ data_model = self.table.model()
+ copied_text = ""
+ for row in range(data_model.rowCount()):
+ for col in range(data_model.columnCount()):
+ index = data_model.index(row, col)
+ cell_text = data_model.data(index)
+ flags = data_model.flags(index)
+ if cell_text is not None:
+ copied_text += cell_text
+ if self.cut and (flags & qt.Qt.ItemIsEditable):
+ data_model.setData(index, "")
+ copied_text += col_separator
+ # remove the right-most tabulation
+ copied_text = copied_text[:-len(col_separator)]
+ # add a newline
+ copied_text += row_separator
+ # remove final newline
+ copied_text = copied_text[:-len(row_separator)]
+
+ # put this text into clipboard
+ qapp = qt.QApplication.instance()
+ qapp.clipboard().setText(copied_text)
+
+
+class CutSelectedCellsAction(CopySelectedCellsAction):
+ """QAction to cut text from selected cells in a :class:`QTableWidget`
+ into the clipboard.
+
+ The text is deleted from the original table widget
+ (use :class:`CopySelectedCellsAction` to preserve the original data).
+
+ If multiple cells are selected, the cut text will be a concatenation
+ of the texts in all selected cells, tabulated with tabulation and
+ newline characters.
+
+ If the cells are sparsely selected, the structure is preserved by
+ representing the unselected cells as empty strings in between two
+ tabulation characters.
+ Beware of pasting this data in another table widget, because depending
+ on how the paste is implemented, the empty cells may cause data in the
+ target table to be deleted, even though you didn't necessarily select the
+ corresponding cell in the origin table.
+
+ :param table: :class:`QTableView` to which this action belongs."""
+ def __init__(self, table):
+ super(CutSelectedCellsAction, self).__init__(table)
+ self.setText("Cut selection")
+ self.setShortcut(qt.QKeySequence.Cut)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ # cutting is already implemented in CopySelectedCellsAction (but
+ # it is disabled), we just need to enable it
+ self.cut = True
+
+
+class CutAllCellsAction(CopyAllCellsAction):
+ """QAction to cut text from all cells in a :class:`QTableWidget`
+ into the clipboard.
+
+ The text is deleted from the original table widget
+ (use :class:`CopyAllCellsAction` to preserve the original data).
+
+ The cut text will be a concatenation
+ of the texts in all cells, tabulated with tabulation and
+ newline characters.
+
+ :param table: :class:`QTableView` to which this action belongs."""
+ def __init__(self, table):
+ super(CutAllCellsAction, self).__init__(table)
+ self.setText("Cut all")
+ self.setToolTip("Cut all cells into the clipboard.")
+ self.cut = True
+
+
+def _parseTextAsTable(text, row_separator=row_separator, col_separator=col_separator):
+ """Parse text into list of lists (2D sequence).
+
+ The input text must be tabulated using tabulation characters and
+ newlines to separate columns and rows.
+
+ :param text: text to be parsed
+ :param record_separator: String, or single character, to be interpreted
+ as a record/row separator.
+ :param field_separator: String, or single character, to be interpreted
+ as a field/column separator.
+ :return: 2D sequence of strings
+ """
+ rows = text.split(row_separator)
+ table_data = [row.split(col_separator) for row in rows]
+ return table_data
+
+
+class PasteCellsAction(qt.QAction):
+ """QAction to paste text from the clipboard into the table.
+
+ If the text contains tabulations and
+ newlines, they are interpreted as column and row separators.
+ In such a case, the text is split into multiple texts to be pasted
+ into multiple cells.
+
+ If a cell content is an empty string in the original text, it is
+ ignored: the destination cell's text will not be deleted.
+
+ :param table: :class:`QTableView` to which this action belongs.
+ """
+ def __init__(self, table):
+ if not isinstance(table, qt.QTableView):
+ raise ValueError('PasteCellsAction must be initialised ' +
+ 'with a QTableWidget.')
+ super(PasteCellsAction, self).__init__(table)
+ self.table = table
+ self.setText("Paste")
+ self.setShortcut(qt.QKeySequence.Paste)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.setToolTip("Paste data. The selected cell is the top-left" +
+ "corner of the paste area.")
+ self.triggered.connect(self.pasteCellFromClipboard)
+
+ def pasteCellFromClipboard(self):
+ """Paste text from clipboard into the table.
+
+ :return: *True* in case of success, *False* if pasting data failed.
+ """
+ selected_idx = self.table.selectedIndexes()
+ if len(selected_idx) != 1:
+ msgBox = qt.QMessageBox(parent=self.table)
+ msgBox.setText("A single cell must be selected to paste data")
+ msgBox.exec_()
+ return False
+
+ data_model = self.table.model()
+
+ selected_row = selected_idx[0].row()
+ selected_col = selected_idx[0].column()
+
+ qapp = qt.QApplication.instance()
+ clipboard_text = qapp.clipboard().text()
+ table_data = _parseTextAsTable(clipboard_text)
+
+ protected_cells = 0
+ out_of_range_cells = 0
+
+ # paste table data into cells, using selected cell as origin
+ for row_offset in range(len(table_data)):
+ for col_offset in range(len(table_data[row_offset])):
+ target_row = selected_row + row_offset
+ target_col = selected_col + col_offset
+
+ if target_row >= data_model.rowCount() or\
+ target_col >= data_model.columnCount():
+ out_of_range_cells += 1
+ continue
+
+ index = data_model.index(target_row, target_col)
+ flags = data_model.flags(index)
+
+ # ignore empty strings
+ if table_data[row_offset][col_offset] != "":
+ if not flags & qt.Qt.ItemIsEditable:
+ protected_cells += 1
+ continue
+ data_model.setData(index, table_data[row_offset][col_offset])
+ # item.setText(table_data[row_offset][col_offset])
+
+ if protected_cells or out_of_range_cells:
+ msgBox = qt.QMessageBox(parent=self.table)
+ msg = "Some data could not be inserted, "
+ msg += "due to out-of-range or write-protected cells."
+ msgBox.setText(msg)
+ msgBox.exec_()
+ return False
+ return True
+
+
+class TableWidget(qt.QTableWidget):
+ """:class:`QTableWidget` with a context menu displaying up to 5 actions:
+
+ - :class:`CopySelectedCellsAction`
+ - :class:`CopyAllCellsAction`
+ - :class:`CutSelectedCellsAction`
+ - :class:`CutAllCellsAction`
+ - :class:`PasteCellsAction`
+
+ These actions interact with the clipboard and can be used to copy data
+ to or from an external application, or another widget.
+
+ The cut and paste actions are disabled by default, due to the risk of
+ overwriting data (no *Undo* action is available). Use :meth:`enablePaste`
+ and :meth:`enableCut` to activate them.
+
+ :param parent: Parent QWidget
+ :param bool cut: Enable cut action
+ :param bool paste: Enable paste action
+ """
+ def __init__(self, parent=None, cut=False, paste=False):
+ super(TableWidget, self).__init__(parent)
+ self.addAction(CopySelectedCellsAction(self))
+ self.addAction(CopyAllCellsAction(self))
+ if cut:
+ self.enableCut()
+ if paste:
+ self.enablePaste()
+
+ self.setContextMenuPolicy(qt.Qt.ActionsContextMenu)
+
+ def enablePaste(self):
+ """Enable paste action, to paste data from the clipboard into the
+ table.
+
+ .. warning::
+
+ This action can cause data to be overwritten.
+ There is currently no *Undo* action to retrieve lost data.
+ """
+ self.addAction(PasteCellsAction(self))
+
+ def enableCut(self):
+ """Enable cut action.
+
+ .. warning::
+
+ This action can cause data to be deleted.
+ There is currently no *Undo* action to retrieve lost data."""
+ self.addAction(CutSelectedCellsAction(self))
+ self.addAction(CutAllCellsAction(self))
+
+
+class TableView(qt.QTableView):
+ """:class:`QTableView` with a context menu displaying up to 5 actions:
+
+ - :class:`CopySelectedCellsAction`
+ - :class:`CopyAllCellsAction`
+ - :class:`CutSelectedCellsAction`
+ - :class:`CutAllCellsAction`
+ - :class:`PasteCellsAction`
+
+ These actions interact with the clipboard and can be used to copy data
+ to or from an external application, or another widget.
+
+ The cut and paste actions are disabled by default, due to the risk of
+ overwriting data (no *Undo* action is available). Use :meth:`enablePaste`
+ and :meth:`enableCut` to activate them.
+
+ .. note::
+
+ These actions will be available only after a model is associated
+ with this view, using :meth:`setModel`.
+
+ :param parent: Parent QWidget
+ :param bool cut: Enable cut action
+ :param bool paste: Enable paste action
+ """
+ def __init__(self, parent=None, cut=False, paste=False):
+ super(TableView, self).__init__(parent)
+ self.cut = cut
+ self.paste = paste
+
+ def setModel(self, model):
+ """Set the data model for the table view, activate the actions
+ and the context menu.
+
+ :param model: :class:`qt.QAbstractItemModel` object
+ """
+ super(TableView, self).setModel(model)
+
+ self.addAction(CopySelectedCellsAction(self))
+ self.addAction(CopyAllCellsAction(self))
+ if self.cut:
+ self.enableCut()
+ if self.paste:
+ self.enablePaste()
+
+ self.setContextMenuPolicy(qt.Qt.ActionsContextMenu)
+
+ def enablePaste(self):
+ """Enable paste action, to paste data from the clipboard into the
+ table.
+
+ .. warning::
+
+ This action can cause data to be overwritten.
+ There is currently no *Undo* action to retrieve lost data.
+ """
+ self.addAction(PasteCellsAction(self))
+
+ def enableCut(self):
+ """Enable cut action.
+
+ .. warning::
+
+ This action can cause data to be deleted.
+ There is currently no *Undo* action to retrieve lost data.
+ """
+ self.addAction(CutSelectedCellsAction(self))
+ self.addAction(CutAllCellsAction(self))
+
+ def addAction(self, action):
+ # ensure the actions are not added multiple times:
+ # compare action type and parent widget with those of existing actions
+ for existing_action in self.actions():
+ if type(action) == type(existing_action):
+ if hasattr(action, "table") and\
+ action.table is existing_action.table:
+ return None
+ super(TableView, self).addAction(action)
+
+if __name__ == "__main__":
+ app = qt.QApplication([])
+
+ tablewidget = TableWidget()
+ tablewidget.setWindowTitle("TableWidget")
+ tablewidget.setColumnCount(10)
+ tablewidget.setRowCount(7)
+ tablewidget.enableCut()
+ tablewidget.enablePaste()
+ tablewidget.show()
+
+ tableview = TableView(cut=True, paste=True)
+ tableview.setWindowTitle("TableView")
+ model = qt.QStandardItemModel()
+ model.setColumnCount(10)
+ model.setRowCount(7)
+ tableview.setModel(model)
+ tableview.show()
+
+ app.exec_()
diff --git a/silx/gui/widgets/ThreadPoolPushButton.py b/silx/gui/widgets/ThreadPoolPushButton.py
new file mode 100644
index 0000000..29e831d
--- /dev/null
+++ b/silx/gui/widgets/ThreadPoolPushButton.py
@@ -0,0 +1,233 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""ThreadPoolPushButton module
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "13/10/2016"
+
+import logging
+from .. import qt
+from .WaitingPushButton import WaitingPushButton
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _Wrapper(qt.QRunnable):
+ """Wrapper to allow to call a function into a `QThreadPool` and
+ sending signals during the life cycle of the object"""
+
+ def __init__(self, signalHolder, function, args, kwargs):
+ """Constructor"""
+ super(_Wrapper, self).__init__()
+ self.__signalHolder = signalHolder
+ self.__callable = function
+ self.__args = args
+ self.__kwargs = kwargs
+
+ def run(self):
+ holder = self.__signalHolder
+ holder.started.emit()
+ try:
+ result = self.__callable(*self.__args, **self.__kwargs)
+ holder.succeeded.emit(result)
+ except Exception as e:
+ module = self.__callable.__module__
+ name = self.__callable.__name__
+ _logger.error("Error while executing callable %s.%s.", module, name, exc_info=True)
+ holder.failed.emit(e)
+ finally:
+ holder.finished.emit()
+ holder._sigReleaseRunner.emit(self)
+
+ def autoDelete(self):
+ """Returns true to ask the QThreadPool to manage the life cycle of
+ this QRunner."""
+ return True
+
+
+class ThreadPoolPushButton(WaitingPushButton):
+ """
+ ThreadPoolPushButton provides a simple push button to execute
+ a threaded task with user feedback when the task is running.
+
+ The task can be defined with the method `setCallable`. It takes a python
+ function and arguments as parameters.
+
+ WARNING: This task is run in a separate thread.
+
+ Everytime the button is pushed a new runner is created to execute the
+ function with defined arguments. An animated waiting icon is displayed
+ to show the activity. By default the button is disabled when an execution
+ is requested. This behaviour can be disabled by using
+ `setDisabledWhenWaiting`.
+
+ When the button is clicked a `beforeExecuting` signal is sent from the
+ Qt main thread. Then the task is started in a thread pool and the following
+ signals are emitted from the thread pool. Right before calling the
+ registered callable, the widget emits a `started` signal.
+ When the task ends, its result is emitted by the `succeeded` signal, but
+ if it fails the signal `failed` is emitted with the resulting exception.
+ At the end, the `finished` signal is emitted.
+
+ The task can be programatically executed by using `executeCallable`.
+
+ >>> # Compute a value
+ >>> import math
+ >>> button = ThreadPoolPushButton(text="Compute 2^16")
+ >>> button.setCallable(math.pow, 2, 16)
+ >>> button.succeeded.connect(print) # python3
+
+ >>> # Compute a wrong value
+ >>> import math
+ >>> button = ThreadPoolPushButton(text="Compute sqrt(-1)")
+ >>> button.setCallable(math.sqrt, -1)
+ >>> button.failed.connect(print) # python3
+ """
+
+ def __init__(self, parent=None, text=None, icon=None):
+ """Constructor
+
+ :param str text: Text displayed on the button
+ :param qt.QIcon icon: Icon displayed on the button
+ :param qt.QWidget parent: Parent of the widget
+ """
+ WaitingPushButton.__init__(self, parent=parent, text=text, icon=icon)
+ self.__callable = None
+ self.__args = None
+ self.__kwargs = None
+ self.__runnerCount = 0
+ self.__runnerSet = set([])
+ self.clicked.connect(self.executeCallable)
+ self.finished.connect(self.__runnerFinished)
+ self._sigReleaseRunner.connect(self.__releaseRunner)
+
+ beforeExecuting = qt.Signal()
+ """Signal emitted just before execution of the callable by the main Qt
+ thread. In synchronous mode (direct mode), it can be used to define
+ dynamically `setCallable`, or to execute something in the Qt thread before
+ the execution, or both."""
+
+ started = qt.Signal()
+ """Signal emitted from the thread pool when the defined callable is
+ started.
+
+ WARNING: This signal is emitted from the thread performing the task, and
+ might be received after the registered callable has been called. If you
+ want to perform some initialisation or set the callable to run, use the
+ `beforeExecuting` signal instead.
+ """
+
+ finished = qt.Signal()
+ """Signal emitted from the thread pool when the defined callable is
+ finished"""
+
+ succeeded = qt.Signal(object)
+ """Signal emitted from the thread pool when the callable exit with a
+ success.
+
+ The parameter of the signal is the result returned by the callable.
+ """
+
+ failed = qt.Signal(object)
+ """Signal emitted emitted from the thread pool when the callable raises an
+ exception.
+
+ The parameter of the signal is the raised exception.
+ """
+
+ _sigReleaseRunner = qt.Signal(object)
+ """Callback to release runners"""
+
+ def __runnerStarted(self):
+ """Called when a runner is started.
+
+ Count the number of executed tasks to change the state of the widget.
+ """
+ self.__runnerCount += 1
+ if self.__runnerCount > 0:
+ self.wait()
+
+ def __runnerFinished(self):
+ """Called when a runner is finished.
+
+ Count the number of executed tasks to change the state of the widget.
+ """
+ self.__runnerCount -= 1
+ if self.__runnerCount <= 0:
+ self.stopWaiting()
+
+ @qt.Slot()
+ def executeCallable(self):
+ """Execute the defined callable in QThreadPool.
+
+ First emit a `beforeExecuting` signal.
+ If callable is not defined, nothing append.
+ If a callable is defined, it will be started
+ as a new thread using the `QThreadPool` system. At start of the thread
+ the `started` will be emitted. When the callable returns a result it
+ is emitted by the `succeeded` signal. If the callable fail, the signal
+ `failed` is emitted with the resulting exception. Then the `finished`
+ signal is emitted.
+ """
+ self.beforeExecuting.emit()
+ if self.__callable is None:
+ return
+ self.__runnerStarted()
+ runner = self._createRunner(self.__callable, self.__args, self.__kwargs)
+ qt.QThreadPool.globalInstance().start(runner)
+ self.__runnerSet.add(runner)
+
+ def __releaseRunner(self, runner):
+ self.__runnerSet.remove(runner)
+
+ def _createRunner(self, function, args, kwargs):
+ """Create a QRunnable from a callable object.
+
+ :param callable function: A callable Python object.
+ :param list args: List of arguments to call the function.
+ :param dict kwargs: Dictionary of arguments used to call the function.
+ :rtpye: qt.QRunnable
+ """
+ runnable = _Wrapper(self, function, args, kwargs)
+ return runnable
+
+ def setCallable(self, function, *args, **kwargs):
+ """Define a callable which will be executed on QThreadPool everytime
+ the button is clicked.
+
+ To retrieve the results, connect to the `succeeded` signal.
+
+ WARNING: The callable will be called in a separate thread.
+
+ :param callable function: A callable Python object
+ :param list args: List of arguments to call the function.
+ :param dict kwargs: Dictionary of arguments used to call the function.
+ """
+ self.__callable = function
+ self.__args = args
+ self.__kwargs = kwargs
diff --git a/silx/gui/widgets/WaitingPushButton.py b/silx/gui/widgets/WaitingPushButton.py
new file mode 100644
index 0000000..49ab9b9
--- /dev/null
+++ b/silx/gui/widgets/WaitingPushButton.py
@@ -0,0 +1,243 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""WaitingPushButton module
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "26/04/2017"
+
+from .. import qt
+from .. import icons
+
+
+class WaitingPushButton(qt.QPushButton):
+ """Button which allows to display a waiting status when, for example,
+ something is still computing.
+
+ The component is graphically disabled when it is in waiting. Then we
+ overwrite the enabled method to dissociate the 2 concepts:
+ graphically enabled/disabled, and enabled/disabled
+ """
+
+ def __init__(self, parent=None, text=None, icon=None):
+ """Constructor
+
+ :param str text: Text displayed on the button
+ :param qt.QIcon icon: Icon displayed on the button
+ :param qt.QWidget parent: Parent of the widget
+ """
+ if icon is not None:
+ qt.QPushButton.__init__(self, icon, text, parent)
+ elif text is not None:
+ qt.QPushButton.__init__(self, text, parent)
+ else:
+ qt.QPushButton.__init__(self, parent)
+
+ self.__waiting = False
+ self.__enabled = True
+ self.__icon = icon
+ self.__disabled_when_waiting = True
+ self.__waitingIcon = icons.getWaitIcon()
+
+ def sizeHint(self):
+ """Returns the recommended size for the widget.
+
+ This implementation of the recommended size always consider there is an
+ icon. In this way it avoid to update the layout when the waiting icon
+ is displayed.
+ """
+ self.ensurePolished()
+
+ w = 0
+ h = 0
+
+ opt = qt.QStyleOptionButton()
+ self.initStyleOption(opt)
+
+ # Content with icon
+ # no condition, assume that there is an icon to avoid blinking
+ # when the widget switch to waiting state
+ ih = opt.iconSize.height()
+ iw = opt.iconSize.width() + 4
+ w += iw
+ h = max(h, ih)
+
+ # Content with text
+ text = self.text()
+ isEmpty = text == ""
+ if isEmpty:
+ text = "XXXX"
+ fm = self.fontMetrics()
+ textSize = fm.size(qt.Qt.TextShowMnemonic, text)
+ if not isEmpty or w == 0:
+ w += textSize.width()
+ if not isEmpty or h == 0:
+ h = max(h, textSize.height())
+
+ # Content with menu indicator
+ opt.rect.setSize(qt.QSize(w, h)) # PM_MenuButtonIndicator depends on the height
+ if self.menu() is not None:
+ w += self.style().pixelMetric(qt.QStyle.PM_MenuButtonIndicator, opt, self)
+
+ contentSize = qt.QSize(w, h)
+ if qt.qVersion().startswith("4.8."):
+ # On PyQt4/PySide the method QCommonStyle sizeFromContents returns
+ # different size when the widget provides an icon or not.
+ # In Qt5 there is not this problem.
+ opt.icon = qt.QIcon()
+ sizeHint = self.style().sizeFromContents(qt.QStyle.CT_PushButton, opt, contentSize, self)
+ sizeHint = sizeHint.expandedTo(qt.QApplication.globalStrut())
+ return sizeHint
+
+ def setDisabledWhenWaiting(self, isDisabled):
+ """Enable or disable the auto disable behaviour when the button is waiting.
+
+ :param bool isDisabled: Enable the auto-disable behaviour
+ """
+ if self.__disabled_when_waiting == isDisabled:
+ return
+ self.__disabled_when_waiting = isDisabled
+ self.__updateVisibleEnabled()
+
+ def isDisabledWhenWaiting(self):
+ """Returns true if the button is auto disabled when it is waiting.
+
+ :rtype: bool
+ """
+ return self.__disabled_when_waiting
+
+ disabledWhenWaiting = qt.Property(bool, isDisabledWhenWaiting, setDisabledWhenWaiting)
+ """Property to enable/disable the auto disabled state when the button is waiting."""
+
+ def __setWaitingIcon(self, icon):
+ """Called when the waiting icon is updated. It is called every frames
+ of the animation.
+
+ :param qt.QIcon icon: The new waiting icon
+ """
+ qt.QPushButton.setIcon(self, icon)
+
+ def setIcon(self, icon):
+ """Set the button icon. If the button is waiting, the icon is not
+ visible directly, but will be visible when the waiting state will be
+ removed.
+
+ :param qt.QIcon icon: An icon
+ """
+ self.__icon = icon
+ self.__updateVisibleIcon()
+
+ def getIcon(self):
+ """Returns the icon set to the button. If the widget is waiting
+ it is not returning the visible icon, but the one requested by
+ the application (the one displayed when the widget is not in
+ waiting state).
+
+ :rtype: qt.QIcon
+ """
+ return self.__icon
+
+ icon = qt.Property(qt.QIcon, getIcon, setIcon)
+ """Property providing access to the icon."""
+
+ def __updateVisibleIcon(self):
+ """Update the visible icon according to the state of the widget."""
+ if not self.isWaiting():
+ icon = self.__icon
+ else:
+ icon = self.__waitingIcon.currentIcon()
+ if icon is None:
+ icon = qt.QIcon()
+ qt.QPushButton.setIcon(self, icon)
+
+ def setEnabled(self, enabled):
+ """Set the enabled state of the widget.
+
+ :param bool enabled: The enabled state
+ """
+ if self.__enabled == enabled:
+ return
+ self.__enabled = enabled
+ self.__updateVisibleEnabled()
+
+ def isEnabled(self):
+ """Returns the enabled state of the widget.
+
+ :rtype: bool
+ """
+ return self.__enabled
+
+ enabled = qt.Property(bool, isEnabled, setEnabled)
+ """Property providing access to the enabled state of the widget"""
+
+ def __updateVisibleEnabled(self):
+ """Update the visible enabled state according to the state of the
+ widget."""
+ if self.__disabled_when_waiting:
+ enabled = not self.isWaiting() and self.__enabled
+ else:
+ enabled = self.__enabled
+ qt.QPushButton.setEnabled(self, enabled)
+
+ def setWaiting(self, waiting):
+ """Set the waiting state of the widget.
+
+ :param bool waiting: Requested state"""
+ if self.__waiting == waiting:
+ return
+ self.__waiting = waiting
+
+ if self.__waiting:
+ self.__waitingIcon.register(self)
+ self.__waitingIcon.iconChanged.connect(self.__setWaitingIcon)
+ else:
+ # unregister only if the object is registred
+ self.__waitingIcon.unregister(self)
+ self.__waitingIcon.iconChanged.disconnect(self.__setWaitingIcon)
+
+ self.__updateVisibleEnabled()
+ self.__updateVisibleIcon()
+
+ def isWaiting(self):
+ """Returns true if the widget is in waiting state.
+
+ :rtype: bool"""
+ return self.__waiting
+
+ @qt.Slot()
+ def wait(self):
+ """Enable the waiting state."""
+ self.setWaiting(True)
+
+ @qt.Slot()
+ def stopWaiting(self):
+ """Disable the waiting state."""
+ self.setWaiting(False)
+
+ @qt.Slot()
+ def swapWaiting(self):
+ """Swap the waiting state."""
+ self.setWaiting(not self.isWaiting())
diff --git a/silx/gui/widgets/__init__.py b/silx/gui/widgets/__init__.py
new file mode 100644
index 0000000..034f4d3
--- /dev/null
+++ b/silx/gui/widgets/__init__.py
@@ -0,0 +1,27 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a few simple Qt widgets that rely only on a Qt wrapper
+for Python (PyQt5, PyQt4 or PySide). No other optional dependencies of *silx*
+should be required."""
diff --git a/silx/gui/widgets/setup.py b/silx/gui/widgets/setup.py
new file mode 100644
index 0000000..e96ac8d
--- /dev/null
+++ b/silx/gui/widgets/setup.py
@@ -0,0 +1,41 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "11/10/2016"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('widgets', parent_package, top_path)
+ config.add_subpackage('test')
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+ setup(configuration=configuration)
diff --git a/silx/gui/widgets/test/__init__.py b/silx/gui/widgets/test/__init__.py
new file mode 100644
index 0000000..afa0f78
--- /dev/null
+++ b/silx/gui/widgets/test/__init__.py
@@ -0,0 +1,45 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import unittest
+
+from . import test_periodictable
+from . import test_tablewidget
+from . import test_threadpoolpushbutton
+from . import test_hierarchicaltableview
+
+__authors__ = ["V. Valls", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "07/04/2017"
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTests(
+ [test_threadpoolpushbutton.suite(),
+ test_tablewidget.suite(),
+ test_periodictable.suite(),
+ test_hierarchicaltableview.suite(),
+ ])
+ return test_suite
diff --git a/silx/gui/widgets/test/test_hierarchicaltableview.py b/silx/gui/widgets/test/test_hierarchicaltableview.py
new file mode 100644
index 0000000..b3d37ed
--- /dev/null
+++ b/silx/gui/widgets/test/test_hierarchicaltableview.py
@@ -0,0 +1,117 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "07/04/2017"
+
+import unittest
+
+from .. import HierarchicalTableView
+from ...test.utils import TestCaseQt
+from silx.gui import qt
+
+
+class TableModel(HierarchicalTableView.HierarchicalTableModel):
+
+ def __init__(self, parent):
+ HierarchicalTableView.HierarchicalTableModel.__init__(self, parent)
+ self.__content = {}
+
+ def rowCount(self, parent=qt.QModelIndex()):
+ return 3
+
+ def columnCount(self, parent=qt.QModelIndex()):
+ return 3
+
+ def setData1(self):
+ if qt.qVersion() > "4.6":
+ self.beginResetModel()
+ else:
+ self.reset()
+
+ content = {}
+ content[0, 0] = ("title", True, (1, 3))
+ content[0, 1] = ("a", True, (2, 1))
+ content[1, 1] = ("b", False, (1, 2))
+ content[1, 2] = ("c", False, (1, 1))
+ content[2, 2] = ("d", False, (1, 1))
+ self.__content = content
+ if qt.qVersion() > "4.6":
+ self.endResetModel()
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ if not index.isValid():
+ return None
+ cell = self.__content.get((index.column(), index.row()), None)
+ if cell is None:
+ return None
+
+ if role == self.SpanRole:
+ return cell[2]
+ elif role == self.IsHeaderRole:
+ return cell[1]
+ elif role == qt.Qt.DisplayRole:
+ return cell[0]
+ return None
+
+
+class TestHierarchicalTableView(TestCaseQt):
+ """Test for HierarchicalTableView"""
+
+ def testEmpty(self):
+ widget = HierarchicalTableView.HierarchicalTableView()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+
+ def testModel(self):
+ widget = HierarchicalTableView.HierarchicalTableView()
+ model = TableModel(widget)
+ # set the data before using the model into the widget
+ model.setData1()
+ widget.setModel(model)
+ span = widget.rowSpan(0, 0), widget.columnSpan(0, 0)
+ self.assertEqual(span, (1, 3))
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+
+ def testModelUpdate(self):
+ widget = HierarchicalTableView.HierarchicalTableView()
+ model = TableModel(widget)
+ widget.setModel(model)
+ # set the data after using the model into the widget
+ model.setData1()
+ span = widget.rowSpan(0, 0), widget.columnSpan(0, 0)
+ self.assertEqual(span, (1, 3))
+
+
+def suite():
+ loader = unittest.defaultTestLoader.loadTestsFromTestCase
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(loader(TestHierarchicalTableView))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/widgets/test/test_periodictable.py b/silx/gui/widgets/test/test_periodictable.py
new file mode 100644
index 0000000..c6bed81
--- /dev/null
+++ b/silx/gui/widgets/test/test_periodictable.py
@@ -0,0 +1,163 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+import unittest
+
+from .. import PeriodicTable
+from ...test.utils import TestCaseQt
+from silx.gui import qt
+
+
+class TestPeriodicTable(TestCaseQt):
+ """Basic test for ArrayTableWidget with a numpy array"""
+
+ def testShow(self):
+ """basic test (instantiation done in setUp)"""
+ pt = PeriodicTable.PeriodicTable()
+ pt.show()
+ self.qWaitForWindowExposed(pt)
+
+ def testSelectable(self):
+ """basic test (instantiation done in setUp)"""
+ pt = PeriodicTable.PeriodicTable(selectable=True)
+ self.assertTrue(pt.selectable)
+
+ def testCustomElements(self):
+ PTI = PeriodicTable.ColoredPeriodicTableItem
+ my_items = [
+ PTI("Xx", 42, 43, 44, "xaxatorium", 1002.2,
+ bgcolor="#FF0000"),
+ PTI("Yy", 25, 22, 44, "yoyotrium", 8.8)
+ ]
+
+ pt = PeriodicTable.PeriodicTable(elements=my_items)
+
+ pt.setSelection(["He", "Xx"])
+ selection = pt.getSelection()
+ self.assertEqual(len(selection), 1) # "He" not found
+ self.assertEqual(selection[0].symbol, "Xx")
+ self.assertEqual(selection[0].Z, 42)
+ self.assertEqual(selection[0].col, 43)
+ self.assertAlmostEqual(selection[0].mass, 1002.2)
+ self.assertEqual(qt.QColor(selection[0].bgcolor),
+ qt.QColor(qt.Qt.red))
+
+ self.assertTrue(pt.isElementSelected("Xx"))
+ self.assertFalse(pt.isElementSelected("Yy"))
+ self.assertRaises(KeyError, pt.isElementSelected, "Yx")
+
+ def testVeryCustomElements(self):
+ class MyPTI(PeriodicTable.PeriodicTableItem):
+ def __init__(self, *args):
+ PeriodicTable.PeriodicTableItem.__init__(self, *args[:6])
+ self.my_feature = args[6]
+
+ my_items = [
+ MyPTI("Xx", 42, 43, 44, "xaxatorium", 1002.2, "spam"),
+ MyPTI("Yy", 25, 22, 44, "yoyotrium", 8.8, "eggs")
+ ]
+
+ pt = PeriodicTable.PeriodicTable(elements=my_items)
+
+ pt.setSelection(["Xx", "Yy"])
+ selection = pt.getSelection()
+ self.assertEqual(len(selection), 2)
+ self.assertEqual(selection[1].symbol, "Yy")
+ self.assertEqual(selection[1].Z, 25)
+ self.assertEqual(selection[1].col, 22)
+ self.assertEqual(selection[1].row, 44)
+ self.assertAlmostEqual(selection[0].mass, 1002.2)
+ self.assertAlmostEqual(selection[0].my_feature, "spam")
+
+
+class TestPeriodicCombo(TestCaseQt):
+ """Basic test for ArrayTableWidget with a numpy array"""
+ def setUp(self):
+ super(TestPeriodicCombo, self).setUp()
+ self.pc = PeriodicTable.PeriodicCombo()
+
+ def tearDown(self):
+ del self.pc
+ super(TestPeriodicCombo, self).tearDown()
+
+ def testShow(self):
+ """basic test (instantiation done in setUp)"""
+ self.pc.show()
+ self.qWaitForWindowExposed(self.pc)
+
+ def testSelect(self):
+ self.pc.setSelection("Sb")
+ selection = self.pc.getSelection()
+ self.assertIsInstance(selection,
+ PeriodicTable.PeriodicTableItem)
+ self.assertEqual(selection.symbol, "Sb")
+ self.assertEqual(selection.Z, 51)
+ self.assertEqual(selection.name, "antimony")
+
+
+class TestPeriodicList(TestCaseQt):
+ """Basic test for ArrayTableWidget with a numpy array"""
+ def setUp(self):
+ super(TestPeriodicList, self).setUp()
+ self.pl = PeriodicTable.PeriodicList()
+
+ def tearDown(self):
+ del self.pl
+ super(TestPeriodicList, self).tearDown()
+
+ def testShow(self):
+ """basic test (instantiation done in setUp)"""
+ self.pl.show()
+ self.qWaitForWindowExposed(self.pl)
+
+ def testSelect(self):
+ self.pl.setSelectedElements(["Li", "He", "Au"])
+ sel_elmts = self.pl.getSelection()
+
+ self.assertEqual(len(sel_elmts), 3,
+ "Wrong number of elements selected")
+ for e in sel_elmts:
+ self.assertIsInstance(e, PeriodicTable.PeriodicTableItem)
+ self.assertIn(e.symbol, ["Li", "He", "Au"])
+ self.assertIn(e.Z, [2, 3, 79])
+ self.assertIn(e.name, ["lithium", "helium", "gold"])
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPeriodicTable))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPeriodicList))
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestPeriodicCombo))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/widgets/test/test_tablewidget.py b/silx/gui/widgets/test/test_tablewidget.py
new file mode 100644
index 0000000..5ad0a06
--- /dev/null
+++ b/silx/gui/widgets/test/test_tablewidget.py
@@ -0,0 +1,61 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test TableWidget"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import unittest
+from silx.gui.test.utils import TestCaseQt
+from silx.gui.widgets.TableWidget import TableWidget
+
+
+class TestTableWidget(TestCaseQt):
+ def setUp(self):
+ super(TestTableWidget, self).setUp()
+ self._result = []
+
+ def testShow(self):
+ table = TableWidget()
+ table.setColumnCount(10)
+ table.setRowCount(7)
+ table.enableCut()
+ table.enablePaste()
+ table.show()
+ table.hide()
+ self.qapp.processEvents()
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestTableWidget))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/widgets/test/test_threadpoolpushbutton.py b/silx/gui/widgets/test/test_threadpoolpushbutton.py
new file mode 100644
index 0000000..126f8f3
--- /dev/null
+++ b/silx/gui/widgets/test/test_threadpoolpushbutton.py
@@ -0,0 +1,129 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test for silx.gui.hdf5 module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "15/12/2016"
+
+
+import unittest
+import time
+from silx.gui import qt
+from silx.gui.test.utils import TestCaseQt
+from silx.gui.test.utils import SignalListener
+from silx.gui.widgets.ThreadPoolPushButton import ThreadPoolPushButton
+from silx.test.utils import TestLogging
+
+
+class TestThreadPoolPushButton(TestCaseQt):
+
+ def setUp(self):
+ super(TestThreadPoolPushButton, self).setUp()
+ self._result = []
+
+ def _trace(self, name, delay=0):
+ self._result.append(name)
+ if delay != 0:
+ time.sleep(delay / 1000.0)
+
+ def _compute(self):
+ return "result"
+
+ def _computeFail(self):
+ raise Exception("exception")
+
+ def testExecute(self):
+ button = ThreadPoolPushButton()
+ button.setCallable(self._trace, "a", 0)
+ button.executeCallable()
+ time.sleep(0.1)
+ self.assertListEqual(self._result, ["a"])
+ self.qapp.processEvents()
+
+ def testMultiExecution(self):
+ button = ThreadPoolPushButton()
+ button.setCallable(self._trace, "a", 0)
+ number = qt.QThreadPool.globalInstance().maxThreadCount() * 2
+ for _ in range(number):
+ button.executeCallable()
+ time.sleep(number * 0.01 + 0.1)
+ self.assertListEqual(self._result, ["a"] * number)
+ self.qapp.processEvents()
+
+ def testSaturateThreadPool(self):
+ button = ThreadPoolPushButton()
+ button.setCallable(self._trace, "a", 100)
+ number = qt.QThreadPool.globalInstance().maxThreadCount() * 2
+ for _ in range(number):
+ button.executeCallable()
+ time.sleep(number * 0.1 + 0.1)
+ self.assertListEqual(self._result, ["a"] * number)
+ self.qapp.processEvents()
+
+ def testSuccess(self):
+ listener = SignalListener()
+ button = ThreadPoolPushButton()
+ button.setCallable(self._compute)
+ button.beforeExecuting.connect(listener.partial(test="be"))
+ button.started.connect(listener.partial(test="s"))
+ button.succeeded.connect(listener.partial(test="result"))
+ button.failed.connect(listener.partial(test="Unexpected exception"))
+ button.finished.connect(listener.partial(test="f"))
+ button.executeCallable()
+ self.qapp.processEvents()
+ time.sleep(0.1)
+ self.qapp.processEvents()
+ result = listener.karguments(argumentName="test")
+ self.assertListEqual(result, ["be", "s", "result", "f"])
+
+ def testFail(self):
+ listener = SignalListener()
+ button = ThreadPoolPushButton()
+ button.setCallable(self._computeFail)
+ button.beforeExecuting.connect(listener.partial(test="be"))
+ button.started.connect(listener.partial(test="s"))
+ button.succeeded.connect(listener.partial(test="Unexpected success"))
+ button.failed.connect(listener.partial(test="exception"))
+ button.finished.connect(listener.partial(test="f"))
+ with TestLogging('silx.gui.widgets.ThreadPoolPushButton', error=1):
+ button.executeCallable()
+ self.qapp.processEvents()
+ time.sleep(0.1)
+ self.qapp.processEvents()
+ result = listener.karguments(argumentName="test")
+ self.assertListEqual(result, ["be", "s", "exception", "f"])
+ listener.clear()
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestThreadPoolPushButton))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')