summaryrefslogtreecommitdiff
path: root/silx
diff options
context:
space:
mode:
Diffstat (limited to 'silx')
-rw-r--r--silx/app/test/test_convert.py4
-rw-r--r--silx/app/view/Viewer.py2
-rw-r--r--silx/app/view/main.py6
-rw-r--r--silx/gui/_glutils/FramebufferTexture.py3
-rw-r--r--silx/gui/_glutils/OpenGLWidget.py14
-rw-r--r--silx/gui/_glutils/Texture.py319
-rw-r--r--silx/gui/_glutils/utils.py30
-rwxr-xr-xsilx/gui/colors.py117
-rw-r--r--silx/gui/data/DataViews.py2
-rw-r--r--silx/gui/data/Hdf5TableView.py68
-rw-r--r--silx/gui/data/NXdataWidgets.py1
-rw-r--r--silx/gui/data/TextFormatter.py8
-rw-r--r--silx/gui/data/test/test_dataviewer.py8
-rw-r--r--silx/gui/data/test/test_textformatter.py28
-rw-r--r--silx/gui/fit/BackgroundWidget.py4
-rw-r--r--silx/gui/fit/FitWidget.py2
-rwxr-xr-xsilx/gui/hdf5/Hdf5Item.py24
-rwxr-xr-xsilx/gui/hdf5/test/test_hdf5.py162
-rw-r--r--silx/gui/plot/ColorBar.py5
-rw-r--r--silx/gui/plot/ComplexImageView.py2
-rw-r--r--silx/gui/plot/CurvesROIWidget.py6
-rw-r--r--silx/gui/plot/ImageStack.py25
-rw-r--r--silx/gui/plot/ImageView.py12
-rw-r--r--silx/gui/plot/MaskToolsWidget.py30
-rw-r--r--silx/gui/plot/PlotInteraction.py19
-rwxr-xr-xsilx/gui/plot/PlotWidget.py186
-rw-r--r--silx/gui/plot/PlotWindow.py107
-rw-r--r--silx/gui/plot/ROIStatsWidget.py780
-rw-r--r--silx/gui/plot/ScatterMaskToolsWidget.py24
-rw-r--r--silx/gui/plot/StackView.py66
-rw-r--r--silx/gui/plot/StatsWidget.py32
-rw-r--r--silx/gui/plot/_BaseMaskToolsWidget.py14
-rw-r--r--silx/gui/plot/_utils/dtime_ticklayout.py16
-rwxr-xr-xsilx/gui/plot/actions/control.py79
-rw-r--r--silx/gui/plot/actions/io.py71
-rwxr-xr-xsilx/gui/plot/backends/BackendBase.py25
-rwxr-xr-xsilx/gui/plot/backends/BackendMatplotlib.py149
-rwxr-xr-xsilx/gui/plot/backends/BackendOpenGL.py426
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotCurve.py86
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotFrame.py159
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotImage.py103
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotItem.py94
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotTriangles.py14
-rw-r--r--silx/gui/plot/backends/glutils/GLText.py60
-rw-r--r--silx/gui/plot/backends/glutils/GLTexture.py5
-rw-r--r--silx/gui/plot/backends/glutils/__init__.py3
-rw-r--r--silx/gui/plot/items/__init__.py3
-rw-r--r--silx/gui/plot/items/_arc_roi.py873
-rw-r--r--silx/gui/plot/items/_pick.py2
-rw-r--r--silx/gui/plot/items/_roi_base.py835
-rw-r--r--silx/gui/plot/items/complex.py15
-rw-r--r--silx/gui/plot/items/core.py189
-rw-r--r--silx/gui/plot/items/curve.py23
-rw-r--r--silx/gui/plot/items/histogram.py35
-rw-r--r--silx/gui/plot/items/image.py79
-rw-r--r--silx/gui/plot/items/roi.py1438
-rw-r--r--silx/gui/plot/items/scatter.py19
-rw-r--r--silx/gui/plot/items/shape.py35
-rw-r--r--silx/gui/plot/matplotlib/__init__.py50
-rw-r--r--silx/gui/plot/stats/stats.py497
-rw-r--r--silx/gui/plot/stats/statshandler.py12
-rw-r--r--silx/gui/plot/test/__init__.py2
-rw-r--r--silx/gui/plot/test/testComplexImageView.py8
-rw-r--r--silx/gui/plot/test/testCurvesROIWidget.py10
-rw-r--r--silx/gui/plot/test/testItem.py90
-rw-r--r--silx/gui/plot/test/testMaskToolsWidget.py15
-rw-r--r--silx/gui/plot/test/testPlotInteraction.py6
-rwxr-xr-xsilx/gui/plot/test/testPlotWidget.py237
-rw-r--r--silx/gui/plot/test/testPlotWindow.py21
-rw-r--r--silx/gui/plot/test/testRoiStatsWidget.py290
-rw-r--r--silx/gui/plot/test/testScatterMaskToolsWidget.py16
-rw-r--r--silx/gui/plot/test/testStackView.py15
-rw-r--r--silx/gui/plot/test/testStats.py273
-rw-r--r--silx/gui/plot/tools/profile/manager.py31
-rw-r--r--silx/gui/plot/tools/profile/rois.py14
-rw-r--r--silx/gui/plot/tools/roi.py239
-rw-r--r--silx/gui/plot/tools/test/testROI.py67
-rw-r--r--silx/gui/plot3d/ScalarFieldView.py6
-rw-r--r--silx/gui/plot3d/items/_pick.py4
-rw-r--r--silx/gui/plot3d/items/core.py54
-rw-r--r--silx/gui/plot3d/items/mixins.py1
-rw-r--r--silx/gui/plot3d/items/volume.py2
-rw-r--r--silx/gui/plot3d/scene/cutplane.py4
-rw-r--r--silx/gui/plot3d/scene/function.py75
-rw-r--r--silx/gui/plot3d/scene/primitives.py10
-rw-r--r--silx/gui/plot3d/scene/text.py3
-rw-r--r--silx/gui/plot3d/scene/transform.py65
-rw-r--r--silx/gui/plot3d/scene/utils.py4
-rw-r--r--silx/gui/plot3d/test/testStatsWidget.py3
-rwxr-xr-xsilx/gui/test/test_colors.py51
-rw-r--r--silx/gui/utils/glutils.py7
-rw-r--r--silx/gui/utils/matplotlib.py71
-rw-r--r--silx/gui/utils/signal.py141
-rw-r--r--silx/gui/utils/testutils.py2
-rw-r--r--silx/gui/widgets/ElidedLabel.py4
-rw-r--r--silx/gui/widgets/test/__init__.py4
-rw-r--r--silx/gui/widgets/test/test_legendiconwidget.py74
-rw-r--r--silx/image/marchingsquares/_mergeimpl.pyx4
-rw-r--r--silx/image/tomography.py2
-rw-r--r--silx/io/commonh5.py22
-rw-r--r--silx/io/dictdump.py421
-rwxr-xr-xsilx/io/fabioh5.py10
-rw-r--r--silx/io/nxdata/parse.py4
-rw-r--r--silx/io/setup.py2
-rw-r--r--silx/io/specfile/src/locale_management.c5
-rw-r--r--silx/io/test/test_dictdump.py257
-rw-r--r--silx/io/test/test_spectoh5.py3
-rw-r--r--silx/io/test/test_url.py10
-rw-r--r--silx/io/test/test_utils.py244
-rw-r--r--silx/io/url.py21
-rw-r--r--silx/io/utils.py331
-rw-r--r--silx/math/colormap.pyx24
-rw-r--r--silx/math/fft/test/test_fft.py8
-rw-r--r--silx/math/fit/bgtheories.py8
-rw-r--r--silx/math/fit/fitmanager.py16
-rw-r--r--silx/math/fit/fittheories.py34
-rw-r--r--silx/math/fit/functions.pyx4
-rw-r--r--silx/math/fit/leastsq.py30
-rw-r--r--silx/math/fit/test/test_fit.py8
-rw-r--r--silx/math/fit/test/test_fitmanager.py12
-rw-r--r--silx/opencl/backprojection.py33
-rw-r--r--silx/opencl/common.py90
-rw-r--r--silx/opencl/convolution.py11
-rw-r--r--silx/opencl/processing.py54
-rw-r--r--silx/opencl/projection.py33
-rw-r--r--silx/opencl/test/test_addition.py28
-rw-r--r--silx/opencl/test/test_backprojection.py3
-rw-r--r--silx/opencl/test/test_convolution.py99
-rw-r--r--silx/resources/gui/icons/add.pngbin0 -> 470 bytes
-rw-r--r--silx/resources/gui/icons/add.svg2
-rw-r--r--silx/resources/gui/icons/backend-opengl.pngbin0 -> 1582 bytes
-rw-r--r--silx/resources/gui/icons/backend-opengl.svg18
-rw-r--r--silx/resources/gui/icons/rm.pngbin0 -> 348 bytes
-rw-r--r--silx/resources/gui/icons/rm.svg2
-rw-r--r--silx/resources/opencl/backproj.cl301
-rw-r--r--silx/resources/opencl/proj.cl4
-rw-r--r--silx/sx/_plot.py4
-rw-r--r--silx/utils/_have_openmp.pxd (renamed from silx/utils/_have_openmp.pxi)0
138 files changed, 8055 insertions, 3531 deletions
diff --git a/silx/app/test/test_convert.py b/silx/app/test/test_convert.py
index bb1ae99..857f30c 100644
--- a/silx/app/test/test_convert.py
+++ b/silx/app/test/test_convert.py
@@ -40,7 +40,7 @@ import h5py
import silx
from .. import convert
from silx.utils import testutils
-
+from silx.io.utils import h5py_read_dataset
# content of a spec file
@@ -137,7 +137,7 @@ class TestConvertCommand(unittest.TestCase):
self.assertTrue(os.path.isfile(h5name))
with h5py.File(h5name, "r") as h5f:
- title12 = h5f["/1.2/title"][()]
+ title12 = h5py_read_dataset(h5f["/1.2/title"])
if sys.version_info < (3, ):
title12 = title12.encode("utf-8")
self.assertEqual(title12,
diff --git a/silx/app/view/Viewer.py b/silx/app/view/Viewer.py
index 9503533..dd4d075 100644
--- a/silx/app/view/Viewer.py
+++ b/silx/app/view/Viewer.py
@@ -116,6 +116,8 @@ class Viewer(qt.QMainWindow):
spliter.addWidget(rightPanel)
spliter.addWidget(self.__dataPanel)
spliter.setStretchFactor(1, 1)
+ spliter.setCollapsible(0, False)
+ spliter.setCollapsible(1, False)
self.__splitter = spliter
main_panel = qt.QWidget(self)
diff --git a/silx/app/view/main.py b/silx/app/view/main.py
index c7afc19..a1369c1 100644
--- a/silx/app/view/main.py
+++ b/silx/app/view/main.py
@@ -1,6 +1,6 @@
# coding: utf-8
# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -57,7 +57,7 @@ def createParser():
default=False,
help='Use OpenGL for plots (instead of matplotlib)')
parser.add_argument(
- '--fresh',
+ '-f', '--fresh',
dest="fresh_preferences",
action="store_true",
default=False,
@@ -104,7 +104,7 @@ def mainQt(options):
from silx.gui import qt
# Make sure matplotlib is configured
# Needed for Debian 8: compatibility between Qt4/Qt5 and old matplotlib
- from silx.gui.plot import matplotlib
+ import silx.gui.utils.matplotlib # noqa
app = qt.QApplication([])
qt.QLocale.setDefault(qt.QLocale.c())
diff --git a/silx/gui/_glutils/FramebufferTexture.py b/silx/gui/_glutils/FramebufferTexture.py
index cc05080..e065030 100644
--- a/silx/gui/_glutils/FramebufferTexture.py
+++ b/silx/gui/_glutils/FramebufferTexture.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -62,6 +62,7 @@ class FramebufferTexture(object):
**kwargs):
self._texture = Texture(internalFormat, shape=shape, **kwargs)
+ self._texture.prepare()
self._previousFramebuffer = 0 # Used by with statement
diff --git a/silx/gui/_glutils/OpenGLWidget.py b/silx/gui/_glutils/OpenGLWidget.py
index 1f7bfae..5e3fcb8 100644
--- a/silx/gui/_glutils/OpenGLWidget.py
+++ b/silx/gui/_glutils/OpenGLWidget.py
@@ -329,6 +329,20 @@ class OpenGLWidget(qt.QWidget):
else:
return self.__openGLWidget.getDevicePixelRatio()
+ def getDotsPerInch(self):
+ """Returns current screen resolution as device pixels per inch.
+
+ :rtype: float
+ """
+ screen = self.window().windowHandle().screen()
+ if screen is not None:
+ # TODO check if this is correct on different OS/screen
+ # OK on macOS10.12/qt5.13.2
+ dpi = screen.physicalDotsPerInch() * self.getDevicePixelRatio()
+ else: # Fallback
+ dpi = 96. * self.getDevicePixelRatio()
+ return dpi
+
def getOpenGLVersion(self):
"""Returns the available OpenGL version.
diff --git a/silx/gui/_glutils/Texture.py b/silx/gui/_glutils/Texture.py
index a7fd44b..c72135a 100644
--- a/silx/gui/_glutils/Texture.py
+++ b/silx/gui/_glutils/Texture.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -81,20 +81,23 @@ class Texture(object):
else:
shape = data.shape
+ self._deferredUpdates = [(format_, data, None)]
+
assert len(shape) in (2, 3)
self._shape = tuple(shape)
self._ndim = len(shape)
self.texUnit = texUnit
- self._name = gl.glGenTextures(1)
- self.bind(self.texUnit)
+ self._texParameterUpdates = {} # Store texture params to update
+
+ self._minFilter = minFilter if minFilter is not None else gl.GL_NEAREST
+ self._texParameterUpdates[gl.GL_TEXTURE_MIN_FILTER] = self._minFilter
- self._minFilter = None
- self.minFilter = minFilter if minFilter is not None else gl.GL_NEAREST
+ self._magFilter = magFilter if magFilter is not None else gl.GL_LINEAR
+ self._texParameterUpdates[gl.GL_TEXTURE_MAG_FILTER] = self._magFilter
- self._magFilter = None
- self.magFilter = magFilter if magFilter is not None else gl.GL_LINEAR
+ self._name = None # Store texture ID
if wrap is not None:
if not isinstance(wrap, abc.Iterable):
@@ -102,69 +105,10 @@ class Texture(object):
assert len(wrap) == self.ndim
- gl.glTexParameter(self.target,
- gl.GL_TEXTURE_WRAP_S,
- wrap[-1])
- gl.glTexParameter(self.target,
- gl.GL_TEXTURE_WRAP_T,
- wrap[-2])
+ self._texParameterUpdates[gl.GL_TEXTURE_WRAP_S] = wrap[-1]
+ self._texParameterUpdates[gl.GL_TEXTURE_WRAP_T] = wrap[-2]
if self.ndim == 3:
- gl.glTexParameter(self.target,
- gl.GL_TEXTURE_WRAP_R,
- wrap[0])
-
- gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
-
- # This are the defaults, useless to set if not modified
- # gl.glPixelStorei(gl.GL_UNPACK_ROW_LENGTH, 0)
- # gl.glPixelStorei(gl.GL_UNPACK_SKIP_PIXELS, 0)
- # gl.glPixelStorei(gl.GL_UNPACK_SKIP_ROWS, 0)
- # gl.glPixelStorei(gl.GL_UNPACK_IMAGE_HEIGHT, 0)
- # gl.glPixelStorei(gl.GL_UNPACK_SKIP_IMAGES, 0)
-
- if data is None:
- data = c_void_p(0)
- type_ = gl.GL_UNSIGNED_BYTE
- else:
- type_ = utils.numpyToGLType(data.dtype)
-
- if self.ndim == 2:
- _logger.debug(
- 'Creating 2D texture shape: (%d, %d),'
- ' internal format: %s, format: %s, type: %s',
- self.shape[0], self.shape[1],
- str(self.internalFormat), str(format_), str(type_))
-
- gl.glTexImage2D(
- gl.GL_TEXTURE_2D,
- 0,
- self.internalFormat,
- self.shape[1],
- self.shape[0],
- 0,
- format_,
- type_,
- data)
- else:
- _logger.debug(
- 'Creating 3D texture shape: (%d, %d, %d),'
- ' internal format: %s, format: %s, type: %s',
- self.shape[0], self.shape[1], self.shape[2],
- str(self.internalFormat), str(format_), str(type_))
-
- gl.glTexImage3D(
- gl.GL_TEXTURE_3D,
- 0,
- self.internalFormat,
- self.shape[2],
- self.shape[1],
- self.shape[0],
- 0,
- format_,
- type_,
- data)
-
- gl.glBindTexture(self.target, 0)
+ self._texParameterUpdates[gl.GL_TEXTURE_WRAP_R] = wrap[0]
@property
def target(self):
@@ -188,12 +132,11 @@ class Texture(object):
@property
def name(self):
- """OpenGL texture name"""
- if self._name is not None:
- return self._name
- else:
- raise RuntimeError(
- "No OpenGL texture resource, discard has already been called")
+ """OpenGL texture name.
+
+ It is None if not initialized or already discarded.
+ """
+ return self._name
@property
def minFilter(self):
@@ -204,10 +147,7 @@ class Texture(object):
def minFilter(self, minFilter):
if minFilter != self.minFilter:
self._minFilter = minFilter
- self.bind()
- gl.glTexParameter(self.target,
- gl.GL_TEXTURE_MIN_FILTER,
- self.minFilter)
+ self._texParameterUpdates[gl.GL_TEXTURE_MIN_FILTER] = minFilter
@property
def magFilter(self):
@@ -218,20 +158,112 @@ class Texture(object):
def magFilter(self, magFilter):
if magFilter != self.magFilter:
self._magFilter = magFilter
- self.bind()
- gl.glTexParameter(self.target,
- gl.GL_TEXTURE_MAG_FILTER,
- self.magFilter)
+ self._texParameterUpdates[gl.GL_TEXTURE_MAG_FILTER] = magFilter
- def discard(self):
- """Delete associated OpenGL texture"""
- if self._name is not None:
- gl.glDeleteTextures(self._name)
- self._name = None
- else:
- _logger.warning("Discard as already been called")
+ def _isPrepareRequired(self) -> bool:
+ """Returns True if OpenGL texture needs to be updated.
- def bind(self, texUnit=None):
+ :rtype: bool
+ """
+ return (self._name is None or
+ self._texParameterUpdates or
+ self._deferredUpdates)
+
+ def _prepareAndBind(self, texUnit=None):
+ """Synchronizes the OpenGL texture"""
+ if self._name is None:
+ self._name = gl.glGenTextures(1)
+
+ self._bind(texUnit)
+
+ # Synchronizes texture parameters
+ for pname, param in self._texParameterUpdates.items():
+ gl.glTexParameter(self.target, pname, param)
+ self._texParameterUpdates = {}
+
+ # Copy data to texture
+ for format_, data, offset in self._deferredUpdates:
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
+
+ # This are the defaults, useless to set if not modified
+ # gl.glPixelStorei(gl.GL_UNPACK_ROW_LENGTH, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_SKIP_PIXELS, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_SKIP_ROWS, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_IMAGE_HEIGHT, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_SKIP_IMAGES, 0)
+
+ if data is None:
+ data = c_void_p(0)
+ type_ = gl.GL_UNSIGNED_BYTE
+ else:
+ type_ = utils.numpyToGLType(data.dtype)
+
+ if offset is None: # Initialize texture
+ if self.ndim == 2:
+ _logger.debug(
+ 'Creating 2D texture shape: (%d, %d),'
+ ' internal format: %s, format: %s, type: %s',
+ self.shape[0], self.shape[1],
+ str(self.internalFormat), str(format_), str(type_))
+
+ gl.glTexImage2D(
+ gl.GL_TEXTURE_2D,
+ 0,
+ self.internalFormat,
+ self.shape[1],
+ self.shape[0],
+ 0,
+ format_,
+ type_,
+ data)
+
+ else:
+ _logger.debug(
+ 'Creating 3D texture shape: (%d, %d, %d),'
+ ' internal format: %s, format: %s, type: %s',
+ self.shape[0], self.shape[1], self.shape[2],
+ str(self.internalFormat), str(format_), str(type_))
+
+ gl.glTexImage3D(
+ gl.GL_TEXTURE_3D,
+ 0,
+ self.internalFormat,
+ self.shape[2],
+ self.shape[1],
+ self.shape[0],
+ 0,
+ format_,
+ type_,
+ data)
+
+ else: # Update already existing texture
+ if self.ndim == 2:
+ gl.glTexSubImage2D(gl.GL_TEXTURE_2D,
+ 0,
+ offset[1],
+ offset[0],
+ data.shape[1],
+ data.shape[0],
+ format_,
+ type_,
+ data)
+
+ else:
+ gl.glTexSubImage3D(gl.GL_TEXTURE_3D,
+ 0,
+ offset[2],
+ offset[1],
+ offset[0],
+ data.shape[2],
+ data.shape[1],
+ data.shape[0],
+ format_,
+ type_,
+ data)
+
+ self._deferredUpdates = []
+
+ def _bind(self, texUnit=None):
"""Bind the texture to a texture unit.
:param int texUnit: The texture unit to use
@@ -241,73 +273,80 @@ class Texture(object):
gl.glActiveTexture(gl.GL_TEXTURE0 + texUnit)
gl.glBindTexture(self.target, self.name)
+ def _unbind(self, texUnit=None):
+ """Reset texture binding to a texture unit.
+
+ :param int texUnit: The texture unit to use
+ """
+ if texUnit is None:
+ texUnit = self.texUnit
+ gl.glActiveTexture(gl.GL_TEXTURE0 + texUnit)
+ gl.glBindTexture(self.target, 0)
+
+ def prepare(self):
+ """Synchronizes the OpenGL texture.
+
+ This method must be called with a current OpenGL context.
+ """
+ if self._isPrepareRequired():
+ self._prepareAndBind()
+ self._unbind()
+
+ def bind(self, texUnit=None):
+ """Bind the texture to a texture unit.
+
+ The OpenGL texture is updated if needed.
+
+ This method must be called with a current OpenGL context.
+
+ :param int texUnit: The texture unit to use
+ """
+ if self._isPrepareRequired():
+ self._prepareAndBind(texUnit)
+ else:
+ self._bind(texUnit)
+
+ def discard(self):
+ """Delete associated OpenGL texture.
+
+ This method must be called with a current OpenGL context.
+ """
+ if self._name is not None:
+ gl.glDeleteTextures(self._name)
+ self._name = None
+ else:
+ _logger.warning("Texture not initialized or already discarded")
+
# with statement
def __enter__(self):
self.bind()
def __exit__(self, exc_type, exc_val, exc_tb):
- gl.glActiveTexture(gl.GL_TEXTURE0 + self.texUnit)
- gl.glBindTexture(self.target, 0)
+ self._unbind()
- def update(self,
- format_,
- data,
- offset=(0, 0, 0),
- texUnit=None):
+ def update(self, format_, data, offset=(0, 0, 0), copy=True):
"""Update the content of the texture.
Texture is not resized, so data must fit into texture with the
given offset.
+ This update is performed lazily during next call to
+ :meth:`prepare` or :meth:`bind`.
+ Data MUST not be changed until then.
+
:param format_: The OpenGL format of the data
:param data: The data to use to update the texture
- :param offset: The offset in the texture where to copy the data
- :type offset: List[int]
- :param int texUnit:
- The texture unit to use (default: the one provided at init)
+ :param List[int] offset: Offset in the texture where to copy the data
+ :param bool copy:
+ True (default) to copy data, False to use as is (do not modify)
"""
- data = numpy.array(data, copy=False, order='C')
+ data = numpy.array(data, copy=copy, order='C')
+ offset = tuple(offset)
assert data.ndim == self.ndim
assert len(offset) >= self.ndim
for i in range(self.ndim):
assert offset[i] + data.shape[i] <= self.shape[i]
- gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
-
- # This are the defaults, useless to set if not modified
- # gl.glPixelStorei(gl.GL_UNPACK_ROW_LENGTH, 0)
- # gl.glPixelStorei(gl.GL_UNPACK_SKIP_PIXELS, 0)
- # gl.glPixelStorei(gl.GL_UNPACK_SKIP_ROWS, 0)
- # gl.glPixelStorei(gl.GL_UNPACK_IMAGE_HEIGHT, 0)
- # gl.glPixelStorei(gl.GL_UNPACK_SKIP_IMAGES, 0)
-
- self.bind(texUnit)
-
- type_ = utils.numpyToGLType(data.dtype)
-
- if self.ndim == 2:
- gl.glTexSubImage2D(gl.GL_TEXTURE_2D,
- 0,
- offset[1],
- offset[0],
- data.shape[1],
- data.shape[0],
- format_,
- type_,
- data)
- gl.glBindTexture(gl.GL_TEXTURE_2D, 0)
- else:
- gl.glTexSubImage3D(gl.GL_TEXTURE_3D,
- 0,
- offset[2],
- offset[1],
- offset[0],
- data.shape[2],
- data.shape[1],
- data.shape[0],
- format_,
- type_,
- data)
- gl.glBindTexture(gl.GL_TEXTURE_3D, 0)
+ self._deferredUpdates.append((format_, data, offset))
diff --git a/silx/gui/_glutils/utils.py b/silx/gui/_glutils/utils.py
index 35cf819..d5627ef 100644
--- a/silx/gui/_glutils/utils.py
+++ b/silx/gui/_glutils/utils.py
@@ -29,45 +29,25 @@ __authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "10/01/2017"
-from . import gl
import numpy
-
-_GL_TYPE_SIZES = {
- gl.GL_FLOAT: 4,
- gl.GL_BYTE: 1,
- gl.GL_SHORT: 2,
- gl.GL_INT: 4,
- gl.GL_UNSIGNED_BYTE: 1,
- gl.GL_UNSIGNED_SHORT: 2,
- gl.GL_UNSIGNED_INT: 4,
-}
+from OpenGL.constants import BYTE_SIZES as _BYTE_SIZES
+from OpenGL.constants import ARRAY_TO_GL_TYPE_MAPPING as _ARRAY_TO_GL_TYPE_MAPPING
def sizeofGLType(type_):
"""Returns the size in bytes of an element of type `type_`"""
- return _GL_TYPE_SIZES[type_]
-
-
-_TYPE_CONVERTER = {
- numpy.dtype(numpy.float32): gl.GL_FLOAT,
- numpy.dtype(numpy.int8): gl.GL_BYTE,
- numpy.dtype(numpy.int16): gl.GL_SHORT,
- numpy.dtype(numpy.int32): gl.GL_INT,
- numpy.dtype(numpy.uint8): gl.GL_UNSIGNED_BYTE,
- numpy.dtype(numpy.uint16): gl.GL_UNSIGNED_SHORT,
- numpy.dtype(numpy.uint32): gl.GL_UNSIGNED_INT,
-}
+ return _BYTE_SIZES[type_]
def isSupportedGLType(type_):
"""Test if a numpy type or dtype can be converted to a GL type."""
- return numpy.dtype(type_) in _TYPE_CONVERTER
+ return numpy.dtype(type_).char in _ARRAY_TO_GL_TYPE_MAPPING
def numpyToGLType(type_):
"""Returns the GL type corresponding the provided numpy type or dtype."""
- return _TYPE_CONVERTER[numpy.dtype(type_)]
+ return _ARRAY_TO_GL_TYPE_MAPPING[numpy.dtype(type_).char]
def segmentTrianglesIntersection(segment, triangles):
diff --git a/silx/gui/colors.py b/silx/gui/colors.py
index 4d750ba..4a96ae0 100755
--- a/silx/gui/colors.py
+++ b/silx/gui/colors.py
@@ -34,7 +34,10 @@ __date__ = "29/01/2019"
import numpy
import logging
import collections
+import warnings
+
from silx.gui import qt
+from silx.gui.utils import blockSignals
from silx.math.combo import min_max
from silx.math import colormap as _colormap
from silx.utils.exceptions import NotEditableError
@@ -45,10 +48,13 @@ from silx.resources import resource_filename as _resource_filename
_logger = logging.getLogger(__file__)
try:
+ import silx.gui.utils.matplotlib # noqa Initalize matplotlib
from matplotlib import cm as _matplotlib_cm
+ from matplotlib.pyplot import colormaps as _matplotlib_colormaps
except ImportError:
_logger.info("matplotlib not available, only embedded colormaps available")
_matplotlib_cm = None
+ _matplotlib_colormaps = None
_COLORDICT = {}
@@ -362,7 +368,22 @@ class _NormalizationMixIn:
if mode == Colormap.MINMAX:
vmin, vmax = self.autoscaleMinMax(data)
elif mode == Colormap.STDDEV3:
- vmin, vmax = self.autoscaleMean3Std(data)
+ dmin, dmax = self.autoscaleMinMax(data)
+ stdmin, stdmax = self.autoscaleMean3Std(data)
+ if dmin is None:
+ vmin = stdmin
+ elif stdmin is None:
+ vmin = dmin
+ else:
+ vmin = max(dmin, stdmin)
+
+ if dmax is None:
+ vmax = stdmax
+ elif stdmax is None:
+ vmax = dmax
+ else:
+ vmax = min(dmax, stdmax)
+
else:
raise ValueError('Unsupported mode: %s' % mode)
@@ -405,7 +426,13 @@ class _NormalizationMixIn:
normdata[numpy.isfinite(normdata) == False] = numpy.nan
if normdata.size == 0: # Fallback
return None, None
- mean, std = numpy.nanmean(normdata), numpy.nanstd(normdata)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', category=RuntimeWarning)
+ # Ignore nanmean "Mean of empty slice" warning and
+ # nanstd "Degrees of freedom <= 0 for slice" warning
+ mean, std = numpy.nanmean(normdata), numpy.nanstd(normdata)
+
return self.revert(mean - 3 * std, 0., 1.), self.revert(mean + 3 * std, 0., 1.)
@@ -426,7 +453,11 @@ class _LinearNormalizationMixIn(_NormalizationMixIn):
data[numpy.isfinite(data) == False] = numpy.nan
if data.size == 0: # Fallback
return None, None
- mean, std = numpy.nanmean(data), numpy.nanstd(data)
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', category=RuntimeWarning)
+ # Ignore nanmean "Mean of empty slice" warning and
+ # nanstd "Degrees of freedom <= 0 for slice" warning
+ mean, std = numpy.nanmean(data), numpy.nanstd(data)
return mean - 3 * std, mean + 3 * std
@@ -534,7 +565,8 @@ class Colormap(qt.QObject):
"""constant for autoscale using min/max data range"""
STDDEV3 = 'stddev3'
- """constant for autoscale using mean +/- 3*std(data)"""
+ """constant for autoscale using mean +/- 3*std(data)
+ with a clamp on min/max of the data"""
AUTOSCALE_MODES = (MINMAX, STDDEV3)
"""Tuple of managed auto scale algorithms"""
@@ -542,10 +574,14 @@ class Colormap(qt.QObject):
sigChanged = qt.Signal()
"""Signal emitted when the colormap has changed."""
+ _DEFAULT_NAN_COLOR = 255, 255, 255, 0
+
def __init__(self, name=None, colors=None, normalization=LINEAR, vmin=None, vmax=None, autoscaleMode=MINMAX):
qt.QObject.__init__(self)
self._editable = True
self.__gamma = 2.0
+ # Default NaN color: fully transparent white
+ self.__nanColor = numpy.array(self._DEFAULT_NAN_COLOR, dtype=numpy.uint8)
assert normalization in Colormap.NORMALIZATIONS
assert autoscaleMode in Colormap.AUTOSCALE_MODES
@@ -593,15 +629,19 @@ class Colormap(qt.QObject):
raise NotEditableError('Colormap is not editable')
if self == other:
return
- old = self.blockSignals(True)
- name = other.getName()
- if name is not None:
- self.setName(name)
- else:
- self.setColormapLUT(other.getColormapLUT())
- self.setNormalization(other.getNormalization())
- self.setVRange(other.getVMin(), other.getVMax())
- self.blockSignals(old)
+ with blockSignals(self):
+ name = other.getName()
+ if name is not None:
+ self.setName(name)
+ else:
+ self.setColormapLUT(other.getColormapLUT())
+ self.setNaNColor(other.getNaNColor())
+ self.setNormalization(other.getNormalization())
+ self.setGammaNormalizationParameter(
+ other.getGammaNormalizationParameter())
+ self.setAutoscaleMode(other.getAutoscaleMode())
+ self.setVRange(*other.getVRange())
+ self.setEditable(other.isEditable())
self.sigChanged.emit()
def getNColors(self, nbColors=None):
@@ -623,7 +663,7 @@ class Colormap(qt.QObject):
colormap.setNormalization(Colormap.LINEAR)
colormap.setVRange(vmin=0, vmax=nbColors - 1)
colors = colormap.applyToData(
- numpy.arange(nbColors, dtype=numpy.int))
+ numpy.arange(nbColors, dtype=numpy.int32))
return colors
def getName(self):
@@ -689,6 +729,24 @@ class Colormap(qt.QObject):
self._name = None
self.sigChanged.emit()
+ def getNaNColor(self):
+ """Returns the color to use for Not-A-Number floating point value.
+
+ :rtype: QColor
+ """
+ return qt.QColor(*self.__nanColor)
+
+ def setNaNColor(self, color):
+ """Set the color to use for Not-A-Number floating point value.
+
+ :param color: RGB(A) color to use for NaN values
+ :type color: QColor, str, tuple of uint8 or float in [0., 1.]
+ """
+ color = (numpy.array(rgba(color)) * 255).astype(numpy.uint8)
+ if not numpy.array_equal(self.__nanColor, color):
+ self.__nanColor = color
+ self.sigChanged.emit()
+
def getNormalization(self):
"""Return the normalization of the colormap.
@@ -1021,8 +1079,10 @@ class Colormap(qt.QObject):
vmax=self._vmax,
normalization=self.getNormalization(),
autoscaleMode=self.getAutoscaleMode())
+ colormap.setNaNColor(self.getNaNColor())
colormap.setGammaNormalizationParameter(
self.getGammaNormalizationParameter())
+ colormap.setEditable(self.isEditable())
return colormap
def applyToData(self, data, reference=None):
@@ -1038,10 +1098,15 @@ class Colormap(qt.QObject):
vmin, vmax = self.getColormapRange(reference)
if hasattr(data, "getColormappedData"): # Use item's data
- data = data.getColormappedData()
+ data = data.getColormappedData(copy=False)
return _colormap.cmap(
- data, self._colors, vmin, vmax, self._getNormalizer())
+ data,
+ self._colors,
+ vmin,
+ vmax,
+ self._getNormalizer(),
+ self.__nanColor)
@staticmethod
def getSupportedColormaps():
@@ -1055,8 +1120,8 @@ class Colormap(qt.QObject):
:rtype: tuple
"""
colormaps = set()
- if _matplotlib_cm is not None:
- colormaps.update(_matplotlib_cm.cmap_d.keys())
+ if _matplotlib_colormaps is not None:
+ colormaps.update(_matplotlib_colormaps())
colormaps.update(_AVAILABLE_LUTS.keys())
colormaps = tuple(cmap for cmap in sorted(colormaps)
@@ -1086,7 +1151,7 @@ class Colormap(qt.QObject):
numpy.array_equal(self.getColormapLUT(), other.getColormapLUT())
)
- _SERIAL_VERSION = 2
+ _SERIAL_VERSION = 3
def restoreState(self, byteArray):
"""
@@ -1106,7 +1171,7 @@ class Colormap(qt.QObject):
return False
version = stream.readUInt32()
- if version not in (1, self._SERIAL_VERSION):
+ if version not in numpy.arange(1, self._SERIAL_VERSION+1):
_logger.warning("Serial version mismatch. Found %d." % version)
return False
@@ -1133,6 +1198,11 @@ class Colormap(qt.QObject):
else:
autoscaleMode = stream.readQString()
+ if version <= 2:
+ nanColor = self._DEFAULT_NAN_COLOR
+ else:
+ nanColor = stream.readInt32(), stream.readInt32(), stream.readInt32(), stream.readInt32()
+
# emit change event only once
old = self.blockSignals(True)
try:
@@ -1142,6 +1212,7 @@ class Colormap(qt.QObject):
self.setVRange(vmin, vmax)
if gamma is not None:
self.setGammaNormalizationParameter(gamma)
+ self.setNaNColor(nanColor)
finally:
self.blockSignals(old)
self.sigChanged.emit()
@@ -1169,6 +1240,12 @@ class Colormap(qt.QObject):
if self.getNormalization() == Colormap.GAMMA:
stream.writeFloat(self.getGammaNormalizationParameter())
stream.writeQString(self.getAutoscaleMode())
+ nanColor = self.getNaNColor()
+ stream.writeInt32(nanColor.red())
+ stream.writeInt32(nanColor.green())
+ stream.writeInt32(nanColor.blue())
+ stream.writeInt32(nanColor.alpha())
+
return data
diff --git a/silx/gui/data/DataViews.py b/silx/gui/data/DataViews.py
index f3b02b9..d9958de 100644
--- a/silx/gui/data/DataViews.py
+++ b/silx/gui/data/DataViews.py
@@ -406,7 +406,7 @@ class DataView(object):
:param NamedTuple selection: Data selected
:rtype: str
"""
- if selection is None:
+ if selection is None or selection.filename is None:
return None
else:
directory, filename = os.path.split(selection.filename)
diff --git a/silx/gui/data/Hdf5TableView.py b/silx/gui/data/Hdf5TableView.py
index 57d6f7b..7749326 100644
--- a/silx/gui/data/Hdf5TableView.py
+++ b/silx/gui/data/Hdf5TableView.py
@@ -380,37 +380,87 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
SEPARATOR = "::"
self.__data.addHeaderRow(headerLabel="Path info")
+ showPhysicalLocation = True
if isinstance(obj, silx.gui.hdf5.H5Node):
# helpful informations if the object come from an HDF5 tree
self.__data.addHeaderValueRow("Basename", lambda x: x.local_basename)
self.__data.addHeaderValueRow("Name", lambda x: x.local_name)
local = lambda x: x.local_filename + SEPARATOR + x.local_name
self.__data.addHeaderValueRow("Local", local)
- physical = lambda x: x.physical_filename + SEPARATOR + x.physical_name
- self.__data.addHeaderValueRow("Physical", physical)
else:
# it's a real H5py object
self.__data.addHeaderValueRow("Basename", lambda x: os.path.basename(x.name))
self.__data.addHeaderValueRow("Name", lambda x: x.name)
if obj.file is not None:
self.__data.addHeaderValueRow("File", lambda x: x.file.filename)
-
if hasattr(obj, "path"):
# That's a link
if hasattr(obj, "filename"):
+ # External link
link = lambda x: x.filename + SEPARATOR + x.path
else:
+ # Soft link
link = lambda x: x.path
self.__data.addHeaderValueRow("Link", link)
- else:
- if silx.io.is_file(obj):
- physical = lambda x: x.filename + SEPARATOR + x.name
+ showPhysicalLocation = False
+
+ # External data (nothing to do with external links)
+ nExtSources = 0
+ firstExtSource = None
+ extType = None
+ if silx.io.is_dataset(hdf5obj):
+ if hasattr(hdf5obj, "is_virtual"):
+ if hdf5obj.is_virtual:
+ extSources = hdf5obj.virtual_sources()
+ if extSources:
+ firstExtSource = extSources[0].file_name + SEPARATOR + extSources[0].dset_name
+ extType = "Virtual"
+ nExtSources = len(extSources)
+ if hasattr(hdf5obj, "external"):
+ extSources = hdf5obj.external
+ if extSources:
+ firstExtSource = extSources[0][0]
+ extType = "Raw"
+ nExtSources = len(extSources)
+
+ if showPhysicalLocation:
+ def _physical_location(x):
+ if isinstance(obj, silx.gui.hdf5.H5Node):
+ return x.physical_filename + SEPARATOR + x.physical_name
+ elif silx.io.is_file(obj):
+ return x.filename + SEPARATOR + x.name
elif obj.file is not None:
- physical = lambda x: x.file.filename + SEPARATOR + x.name
+ return x.file.filename + SEPARATOR + x.name
else:
# Guess it is a virtual node
- physical = "No physical location"
- self.__data.addHeaderValueRow("Physical", physical)
+ return "No physical location"
+
+ self.__data.addHeaderValueRow("Physical", _physical_location)
+
+ if extType:
+ def _first_source(x):
+ # Absolute path
+ if os.path.isabs(firstExtSource):
+ return firstExtSource
+
+ # Relative path with respect to the file directory
+ if isinstance(obj, silx.gui.hdf5.H5Node):
+ filename = x.physical_filename
+ elif silx.io.is_file(obj):
+ filename = x.filename
+ elif obj.file is not None:
+ filename = x.file.filename
+ else:
+ return firstExtSource
+
+ if firstExtSource[0] == ".":
+ firstExtSource.pop(0)
+ return os.path.join(os.path.dirname(filename), firstExtSource)
+
+ self.__data.addHeaderRow(headerLabel="External sources")
+ self.__data.addHeaderValueRow("Type", extType)
+ self.__data.addHeaderValueRow("Count", str(nExtSources))
+ self.__data.addHeaderValueRow("First", _first_source)
if hasattr(obj, "dtype"):
diff --git a/silx/gui/data/NXdataWidgets.py b/silx/gui/data/NXdataWidgets.py
index 224f337..271b267 100644
--- a/silx/gui/data/NXdataWidgets.py
+++ b/silx/gui/data/NXdataWidgets.py
@@ -370,6 +370,7 @@ class ArrayImagePlot(qt.QWidget):
vmin=None, vmax=None,
normalization=Colormap.LINEAR))
self._plot.getIntensityHistogramAction().setVisible(True)
+ self._plot.setKeepDataAspectRatio(True)
# not closable
self._selector = NumpyAxesSelector(self)
diff --git a/silx/gui/data/TextFormatter.py b/silx/gui/data/TextFormatter.py
index 98c37d7..8fd7c7c 100644
--- a/silx/gui/data/TextFormatter.py
+++ b/silx/gui/data/TextFormatter.py
@@ -267,6 +267,12 @@ class TextFormatter(qt.QObject):
if vlen is not None:
if vlen == six.text_type:
# HDF5 UTF8
+ # With h5py>=3 reading dataset returns bytes
+ if isinstance(data, (bytes, numpy.bytes_)):
+ try:
+ data = data.decode("utf-8")
+ except UnicodeDecodeError:
+ self.__formatSafeAscii(data)
return self.__formatText(data)
elif vlen == six.binary_type:
# HDF5 ASCII
@@ -289,7 +295,7 @@ class TextFormatter(qt.QObject):
elif isinstance(data, list):
text = [self.toString(d) for d in data]
return "[" + " ".join(text) + "]"
- elif isinstance(data, (numpy.ndarray)):
+ elif isinstance(data, numpy.ndarray):
if dtype is None:
dtype = data.dtype
if data.shape == ():
diff --git a/silx/gui/data/test/test_dataviewer.py b/silx/gui/data/test/test_dataviewer.py
index 12a640e..dd01dd6 100644
--- a/silx/gui/data/test/test_dataviewer.py
+++ b/silx/gui/data/test/test_dataviewer.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -108,7 +108,7 @@ class AbstractDataViewerTests(TestCaseQt):
self.assertIn(DataViews.IMAGE_MODE, availableModes)
def test_image_bool(self):
- data = numpy.zeros((10, 10), dtype=numpy.bool)
+ data = numpy.zeros((10, 10), dtype=bool)
data[::2, ::2] = True
widget = self.create_widget()
widget.setData(data)
@@ -117,7 +117,7 @@ class AbstractDataViewerTests(TestCaseQt):
self.assertIn(DataViews.IMAGE_MODE, availableModes)
def test_image_complex_data(self):
- data = numpy.arange(3 ** 2, dtype=numpy.complex)
+ data = numpy.arange(3 ** 2, dtype=numpy.complex64)
data.shape = [3] * 2
widget = self.create_widget()
widget.setData(data)
@@ -262,7 +262,7 @@ class TestDataView(TestCaseQt):
line = [1, 2j, 3 + 3j, 4]
image = [line, line, line, line]
cube = [image, image, image, image]
- data = numpy.array(cube, dtype=numpy.complex)
+ data = numpy.array(cube, dtype=numpy.complex64)
return data
def createDataViewWithData(self, dataViewClass, data):
diff --git a/silx/gui/data/test/test_textformatter.py b/silx/gui/data/test/test_textformatter.py
index 1a63074..d3050bf 100644
--- a/silx/gui/data/test/test_textformatter.py
+++ b/silx/gui/data/test/test_textformatter.py
@@ -36,6 +36,7 @@ import six
from silx.gui.utils.testutils import TestCaseQt
from silx.gui.utils.testutils import SignalListener
from ..TextFormatter import TextFormatter
+from silx.io.utils import h5py_read_dataset
import h5py
@@ -123,76 +124,79 @@ class TestTextFormatterWithH5py(TestCaseQt):
dataset = self.h5File.create_dataset(testName, data=data, dtype=dtype)
return dataset
+ def read_dataset(self, d):
+ return self.formatter.toString(d[()], dtype=d.dtype)
+
def testAscii(self):
d = self.create_dataset(data=b"abc")
- result = self.formatter.toString(d[()], dtype=d.dtype)
+ result = self.read_dataset(d)
self.assertEqual(result, '"abc"')
def testUnicode(self):
d = self.create_dataset(data=u"i\u2661cookies")
- result = self.formatter.toString(d[()], dtype=d.dtype)
+ result = self.read_dataset(d)
self.assertEqual(len(result), 11)
self.assertEqual(result, u'"i\u2661cookies"')
def testBadAscii(self):
d = self.create_dataset(data=b"\xF0\x9F\x92\x94")
- result = self.formatter.toString(d[()], dtype=d.dtype)
+ result = self.read_dataset(d)
self.assertEqual(result, 'b"\\xF0\\x9F\\x92\\x94"')
def testVoid(self):
d = self.create_dataset(data=numpy.void(b"abc\xF0"))
- result = self.formatter.toString(d[()], dtype=d.dtype)
+ result = self.read_dataset(d)
self.assertEqual(result, 'b"\\x61\\x62\\x63\\xF0"')
def testEnum(self):
dtype = h5py.special_dtype(enum=('i', {"RED": 0, "GREEN": 1, "BLUE": 42}))
d = numpy.array(42, dtype=dtype)
d = self.create_dataset(data=d)
- result = self.formatter.toString(d[()], dtype=d.dtype)
+ result = self.read_dataset(d)
self.assertEqual(result, 'BLUE(42)')
def testRef(self):
dtype = h5py.special_dtype(ref=h5py.Reference)
d = numpy.array(self.h5File.ref, dtype=dtype)
d = self.create_dataset(data=d)
- result = self.formatter.toString(d[()], dtype=d.dtype)
+ result = self.read_dataset(d)
self.assertEqual(result, 'REF')
def testArrayAscii(self):
d = self.create_dataset(data=[b"abc"])
- result = self.formatter.toString(d[()], dtype=d.dtype)
+ result = self.read_dataset(d)
self.assertEqual(result, '["abc"]')
def testArrayUnicode(self):
dtype = h5py.special_dtype(vlen=six.text_type)
d = numpy.array([u"i\u2661cookies"], dtype=dtype)
d = self.create_dataset(data=d)
- result = self.formatter.toString(d[()], dtype=d.dtype)
+ result = self.read_dataset(d)
self.assertEqual(len(result), 13)
self.assertEqual(result, u'["i\u2661cookies"]')
def testArrayBadAscii(self):
d = self.create_dataset(data=[b"\xF0\x9F\x92\x94"])
- result = self.formatter.toString(d[()], dtype=d.dtype)
+ result = self.read_dataset(d)
self.assertEqual(result, '[b"\\xF0\\x9F\\x92\\x94"]')
def testArrayVoid(self):
d = self.create_dataset(data=numpy.void([b"abc\xF0"]))
- result = self.formatter.toString(d[()], dtype=d.dtype)
+ result = self.read_dataset(d)
self.assertEqual(result, '[b"\\x61\\x62\\x63\\xF0"]')
def testArrayEnum(self):
dtype = h5py.special_dtype(enum=('i', {"RED": 0, "GREEN": 1, "BLUE": 42}))
d = numpy.array([42, 1, 100], dtype=dtype)
d = self.create_dataset(data=d)
- result = self.formatter.toString(d[()], dtype=d.dtype)
+ result = self.read_dataset(d)
self.assertEqual(result, '[BLUE(42) GREEN(1) 100]')
def testArrayRef(self):
dtype = h5py.special_dtype(ref=h5py.Reference)
d = numpy.array([self.h5File.ref, None], dtype=dtype)
d = self.create_dataset(data=d)
- result = self.formatter.toString(d[()], dtype=d.dtype)
+ result = self.read_dataset(d)
self.assertEqual(result, '[REF NULL_REF]')
diff --git a/silx/gui/fit/BackgroundWidget.py b/silx/gui/fit/BackgroundWidget.py
index 2171e87..76bc043 100644
--- a/silx/gui/fit/BackgroundWidget.py
+++ b/silx/gui/fit/BackgroundWidget.py
@@ -1,6 +1,6 @@
# coding: utf-8
#/*##########################################################################
-# Copyright (C) 2004-2017 V.A. Sole, European Synchrotron Radiation Facility
+# Copyright (C) 2004-2020 V.A. Sole, European Synchrotron Radiation Facility
#
# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
# the ESRF by the Software group.
@@ -337,7 +337,7 @@ class BackgroundWidget(qt.QWidget):
pars = self.getParameters()
# smoothed data
- y = numpy.ravel(numpy.array(self._y)).astype(numpy.float)
+ y = numpy.ravel(numpy.array(self._y)).astype(numpy.float64)
if pars["SmoothingFlag"]:
ysmooth = filters.savitsky_golay(y, pars['SmoothingWidth'])
f = [0.25, 0.5, 0.25]
diff --git a/silx/gui/fit/FitWidget.py b/silx/gui/fit/FitWidget.py
index 7279cd9..08731f1 100644
--- a/silx/gui/fit/FitWidget.py
+++ b/silx/gui/fit/FitWidget.py
@@ -720,7 +720,7 @@ class FitWidget(qt.QWidget):
if __name__ == "__main__":
import numpy
- x = numpy.arange(1500).astype(numpy.float)
+ x = numpy.arange(1500).astype(numpy.float64)
constant_bg = 3.14
p = [1000, 100., 30.0,
diff --git a/silx/gui/hdf5/Hdf5Item.py b/silx/gui/hdf5/Hdf5Item.py
index 11a08b6..e07f835 100755
--- a/silx/gui/hdf5/Hdf5Item.py
+++ b/silx/gui/hdf5/Hdf5Item.py
@@ -100,7 +100,7 @@ class Hdf5Item(Hdf5Node):
"""Returns the class of the stored object.
When the object is in lazy loading, this method should be able to
- return the type of the futrue loaded object. It allows to delay the
+ return the type of the future loaded object. It allows to delay the
real load of the object.
:rtype: silx.io.utils.H5Type
@@ -114,7 +114,7 @@ class Hdf5Item(Hdf5Node):
"""Returns the class of the stored object.
When the object is in lazy loading, this method should be able to
- return the type of the futrue loaded object. It allows to delay the
+ return the type of the future loaded object. It allows to delay the
real load of the object.
:rtype: h5py.File or h5py.Dataset or h5py.Group
@@ -383,12 +383,13 @@ class Hdf5Item(Hdf5Node):
text = text.strip('"')
# Check NX_class formatting
lower = text.lower()
+ formatedNX_class = ""
if lower.startswith('nx'):
formatedNX_class = 'NX' + lower[2:]
if lower == 'nxcansas':
formatedNX_class = 'NXcanSAS' # That's the only class with capital letters...
if text != formatedNX_class:
- _logger.error("NX_class: %s is malformed (should be %s)",
+ _logger.error("NX_class: '%s' is malformed (should be '%s')",
text,
formatedNX_class)
text = formatedNX_class
@@ -614,17 +615,28 @@ class Hdf5Item(Hdf5Node):
if role == qt.Qt.TextAlignmentRole:
return qt.Qt.AlignTop | qt.Qt.AlignLeft
if role == qt.Qt.DisplayRole:
+ # Mark as link
link = self.linkClass
if link is None:
- return ""
+ pass
+ elif link == silx.io.utils.H5Type.HARD_LINK:
+ pass
elif link == silx.io.utils.H5Type.EXTERNAL_LINK:
return "External"
elif link == silx.io.utils.H5Type.SOFT_LINK:
return "Soft"
- elif link == silx.io.utils.H5Type.HARD_LINK:
- return ""
else:
return link.__name__
+ # Mark as external data
+ if self.h5Class == silx.io.utils.H5Type.DATASET:
+ obj = self.obj
+ if hasattr(obj, "is_virtual"):
+ if obj.is_virtual:
+ return "Virtual"
+ if hasattr(obj, "external"):
+ if obj.external:
+ return "ExtRaw"
+ return ""
if role == qt.Qt.ToolTipRole:
return None
return None
diff --git a/silx/gui/hdf5/test/test_hdf5.py b/silx/gui/hdf5/test/test_hdf5.py
index 5bd4223..fcfc02c 100755
--- a/silx/gui/hdf5/test/test_hdf5.py
+++ b/silx/gui/hdf5/test/test_hdf5.py
@@ -589,11 +589,11 @@ class TestNexusSortFilterProxyModel(TestCaseQt):
self.assertListEqual(names, ["100aaa", "aaa100"])
-class TestH5Node(TestCaseQt):
+class _TestModelBase(TestCaseQt):
@classmethod
def setUpClass(cls):
- super(TestH5Node, cls).setUpClass()
+ super(_TestModelBase, cls).setUpClass()
cls.tmpDirectory = tempfile.mkdtemp()
cls.h5Filename = cls.createResource(cls.tmpDirectory)
@@ -603,13 +603,18 @@ class TestH5Node(TestCaseQt):
@classmethod
def createResource(cls, directory):
filename = os.path.join(directory, "base.h5")
- externalFilename = os.path.join(directory, "base__external.h5")
+ extH5FileName = os.path.join(directory, "base__external.h5")
+ extDatFileName = os.path.join(directory, "base__external.dat")
- externalh5 = h5py.File(externalFilename, mode="w")
+ externalh5 = h5py.File(extH5FileName, mode="w")
externalh5["target/dataset"] = 50
externalh5["target/link"] = h5py.SoftLink("/target/dataset")
+ externalh5["/ext/vds0"] = [0, 1]
+ externalh5["/ext/vds1"] = [2, 3]
externalh5.close()
+ numpy.array([0,1,10,10,2,3]).tofile(extDatFileName)
+
h5 = h5py.File(filename, mode="w")
h5["group/dataset"] = 50
h5["link/soft_link"] = h5py.SoftLink("/group/dataset")
@@ -617,12 +622,19 @@ class TestH5Node(TestCaseQt):
h5["link/soft_link_to_link"] = h5py.SoftLink("/link/soft_link")
h5["link/soft_link_to_file"] = h5py.SoftLink("/")
h5["group/soft_link_relative"] = h5py.SoftLink("dataset")
- h5["link/external_link"] = h5py.ExternalLink(externalFilename, "/target/dataset")
- h5["link/external_link_to_link"] = h5py.ExternalLink(externalFilename, "/target/link")
- h5["broken_link/external_broken_file"] = h5py.ExternalLink(externalFilename + "_not_exists", "/target/link")
- h5["broken_link/external_broken_link"] = h5py.ExternalLink(externalFilename, "/target/not_exists")
+ h5["link/external_link"] = h5py.ExternalLink(extH5FileName, "/target/dataset")
+ h5["link/external_link_to_link"] = h5py.ExternalLink(extH5FileName, "/target/link")
+ h5["broken_link/external_broken_file"] = h5py.ExternalLink(extH5FileName + "_not_exists", "/target/link")
+ h5["broken_link/external_broken_link"] = h5py.ExternalLink(extH5FileName, "/target/not_exists")
h5["broken_link/soft_broken_link"] = h5py.SoftLink("/group/not_exists")
h5["broken_link/soft_link_to_broken_link"] = h5py.SoftLink("/group/not_exists")
+ layout = h5py.VirtualLayout((2,2), dtype=int)
+ layout[0] = h5py.VirtualSource("base__external.h5", name="/ext/vds0", shape=(2,), dtype=int)
+ layout[1] = h5py.VirtualSource("base__external.h5", name="/ext/vds1", shape=(2,), dtype=int)
+ h5.create_group("/ext")
+ h5["/ext"].create_virtual_dataset("virtual", layout)
+ external = [("base__external.dat", 0, 2*8), ("base__external.dat", 4*8, 2*8)]
+ h5["/ext"].create_dataset("raw", shape=(2,2), dtype=int, external=external)
h5.close()
return filename
@@ -640,7 +652,7 @@ class TestH5Node(TestCaseQt):
cls.qWaitForDestroy(ref)
cls.h5File.close()
shutil.rmtree(cls.tmpDirectory)
- super(TestH5Node, cls).tearDownClass()
+ super(_TestModelBase, cls).tearDownClass()
def getIndexFromPath(self, model, path):
"""
@@ -658,9 +670,114 @@ class TestH5Node(TestCaseQt):
raise RuntimeError("Path not found")
return index
- def getH5NodeFromPath(self, model, path):
+ def getH5ItemFromPath(self, model, path):
index = self.getIndexFromPath(model, path)
- item = model.data(index, hdf5.Hdf5TreeModel.H5PY_ITEM_ROLE)
+ return model.data(index, hdf5.Hdf5TreeModel.H5PY_ITEM_ROLE)
+
+
+class TestH5Item(_TestModelBase):
+
+ def testFile(self):
+ path = ["base.h5"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ def testGroup(self):
+ path = ["base.h5", "group"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ def testDataset(self):
+ path = ["base.h5", "group", "dataset"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ def testSoftLink(self):
+ path = ["base.h5", "link", "soft_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testSoftLinkToLink(self):
+ path = ["base.h5", "link", "soft_link_to_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testSoftLinkRelative(self):
+ path = ["base.h5", "group", "soft_link_relative"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testExternalLink(self):
+ path = ["base.h5", "link", "external_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
+
+ def testExternalLinkToLink(self):
+ path = ["base.h5", "link", "external_link_to_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
+
+ def testExternalBrokenFile(self):
+ path = ["base.h5", "broken_link", "external_broken_file"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
+
+ def testExternalBrokenLink(self):
+ path = ["base.h5", "broken_link", "external_broken_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
+
+ def testSoftBrokenLink(self):
+ path = ["base.h5", "broken_link", "soft_broken_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testSoftLinkToBrokenLink(self):
+ path = ["base.h5", "broken_link", "soft_link_to_broken_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testDatasetFromSoftLinkToGroup(self):
+ path = ["base.h5", "link", "soft_link_to_group", "dataset"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ def testDatasetFromSoftLinkToFile(self):
+ path = ["base.h5", "link", "soft_link_to_file", "link", "soft_link_to_group", "dataset"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ def testExternalVirtual(self):
+ path = ["base.h5", "ext", "virtual"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Virtual")
+
+ def testExternalRaw(self):
+ path = ["base.h5", "ext", "raw"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "ExtRaw")
+
+
+class TestH5Node(_TestModelBase):
+
+ def getH5NodeFromPath(self, model, path):
+ item = self.getH5ItemFromPath(model, path)
h5node = hdf5.H5Node(item)
return h5node
@@ -824,6 +941,28 @@ class TestH5Node(TestCaseQt):
self.assertEqual(h5node.local_basename, "dataset")
self.assertEqual(h5node.local_name, "/link/soft_link_to_file/link/soft_link_to_group/dataset")
+ def testExternalVirtual(self):
+ path = ["base.h5", "ext", "virtual"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "virtual")
+ self.assertEqual(h5node.physical_name, "/ext/virtual")
+ self.assertEqual(h5node.local_basename, "virtual")
+ self.assertEqual(h5node.local_name, "/ext/virtual")
+
+ def testExternalRaw(self):
+ path = ["base.h5", "ext", "raw"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "raw")
+ self.assertEqual(h5node.physical_name, "/ext/raw")
+ self.assertEqual(h5node.local_basename, "raw")
+ self.assertEqual(h5node.local_name, "/ext/raw")
+
class TestHdf5TreeView(TestCaseQt):
"""Test to check that icons module."""
@@ -993,6 +1132,7 @@ def suite():
test_suite.addTest(loadTests(TestNexusSortFilterProxyModel))
test_suite.addTest(loadTests(TestHdf5TreeView))
test_suite.addTest(loadTests(TestH5Node))
+ test_suite.addTest(loadTests(TestH5Item))
return test_suite
diff --git a/silx/gui/plot/ColorBar.py b/silx/gui/plot/ColorBar.py
index 2b4677b..eff7689 100644
--- a/silx/gui/plot/ColorBar.py
+++ b/silx/gui/plot/ColorBar.py
@@ -142,11 +142,8 @@ class ColorBarWidget(qt.QWidget):
self._isConnected = True
def setVisible(self, isVisible):
- # isHidden looks to be always synchronized, while isVisible is not
- wasHidden = self.isHidden()
qt.QWidget.setVisible(self, isVisible)
- if wasHidden != self.isHidden():
- self.sigVisibleChanged.emit(not self.isHidden())
+ self.sigVisibleChanged.emit(isVisible)
def showEvent(self, event):
self._connectPlot()
diff --git a/silx/gui/plot/ComplexImageView.py b/silx/gui/plot/ComplexImageView.py
index cd891cc..dc6bf63 100644
--- a/silx/gui/plot/ComplexImageView.py
+++ b/silx/gui/plot/ComplexImageView.py
@@ -318,7 +318,7 @@ class ComplexImageView(qt.QWidget):
False to use provided data (do not modify!).
"""
if data is None:
- data = numpy.zeros((0, 0), dtype=numpy.complex)
+ data = numpy.zeros((0, 0), dtype=numpy.complex64)
previousData = self._plotImage.getComplexData(copy=False)
diff --git a/silx/gui/plot/CurvesROIWidget.py b/silx/gui/plot/CurvesROIWidget.py
index 4865b8e..5c9033e 100644
--- a/silx/gui/plot/CurvesROIWidget.py
+++ b/silx/gui/plot/CurvesROIWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -1215,14 +1215,14 @@ class ROI(_RegionOfInterestBase):
if len(idx):
xw = x[idx]
yw = y[idx]
- rawCounts = yw.sum(dtype=numpy.float)
+ rawCounts = yw.sum(dtype=numpy.float64)
deltaX = xw[-1] - xw[0]
deltaY = yw[-1] - yw[0]
if deltaX > 0.0:
slope = (deltaY / deltaX)
background = yw[0] + slope * (xw - xw[0])
netCounts = (rawCounts -
- background.sum(dtype=numpy.float))
+ background.sum(dtype=numpy.float64))
else:
netCounts = 0.0
else:
diff --git a/silx/gui/plot/ImageStack.py b/silx/gui/plot/ImageStack.py
index c620d6d..3b652ca 100644
--- a/silx/gui/plot/ImageStack.py
+++ b/silx/gui/plot/ImageStack.py
@@ -150,7 +150,10 @@ class UrlList(qt.QWidget):
self._listWidget.addItems(url_names)
def _notifyCurrentUrlChanged(self, current, previous):
- self.sigCurrentUrlChanged.emit(current.text())
+ if current is None:
+ pass
+ else:
+ self.sigCurrentUrlChanged.emit(current.text())
def setUrl(self, url: DataUrl) -> None:
assert isinstance(url, DataUrl)
@@ -163,6 +166,9 @@ class UrlList(qt.QWidget):
self._listWidget.setCurrentItem(item)
self.sigCurrentUrlChanged.emit(item.text())
+ def clear(self):
+ self._listWidget.clear()
+
class _ToggleableUrlSelectionTable(qt.QWidget):
@@ -214,6 +220,9 @@ class _ToggleableUrlSelectionTable(qt.QWidget):
def _propagateSignal(self, url):
self.sigCurrentUrlChanged.emit(url)
+ def clear(self):
+ self._urlsTable.clear()
+
class UrlLoader(qt.QThread):
"""
@@ -326,6 +335,8 @@ class ImageStack(qt.QMainWindow):
self._urlData = OrderedDict({})
self._current_url = None
self._plot.clear()
+ self._urlsTable.clear()
+ self._slider.setMaximum(-1)
def _preFetch(self, urls: list) -> None:
"""Pre-fetch the given urls if necessary
@@ -414,14 +425,16 @@ class ImageStack(qt.QMainWindow):
self._urlsTable.blockSignals(old_url_table)
old_slider = self._slider.blockSignals(True)
+ self._slider.setMinimum(0)
self._slider.setMaximum(len(self._urls) - 1)
self._slider.blockSignals(old_slider)
if self.getCurrentUrl() in self._urls:
self.setCurrentUrl(self.getCurrentUrl())
else:
- first_url = self._urls[list(self._urls.keys())[0]]
- self.setCurrentUrl(first_url)
+ if len(self._urls.keys()) > 0:
+ first_url = self._urls[list(self._urls.keys())[0]]
+ self.setCurrentUrl(first_url)
def getUrls(self) -> tuple:
"""
@@ -516,7 +529,11 @@ class ImageStack(qt.QMainWindow):
:param index: url to be displayed
:type: int
"""
- if index >= len(self._urls):
+ if index < 0:
+ return
+ if self._urls is None:
+ return
+ elif index >= len(self._urls):
raise ValueError('requested index out of bounds')
else:
return self.setCurrentUrl(self._urls[index])
diff --git a/silx/gui/plot/ImageView.py b/silx/gui/plot/ImageView.py
index fafd49f..8cc0cc6 100644
--- a/silx/gui/plot/ImageView.py
+++ b/silx/gui/plot/ImageView.py
@@ -56,7 +56,7 @@ from ..colors import Colormap
from ..colors import cursorColorForColormap
from .tools import LimitsToolBar
from .Profile import ProfileToolBar
-
+from ...utils.proxy import docstring
_logger = logging.getLogger(__name__)
@@ -341,6 +341,10 @@ class ImageView(PlotWindow):
self._radarView = RadarView(parent=self)
self._radarView.visibleRectDragged.connect(self._radarViewCB)
+ self.__setCentralWidget()
+
+ def __setCentralWidget(self):
+ """Set central widget with all its content"""
layout = qt.QGridLayout()
layout.addWidget(self.getWidgetHandle(), 0, 0)
layout.addWidget(self._histoVPlot.getWidgetHandle(), 0, 1)
@@ -365,6 +369,12 @@ class ImageView(PlotWindow):
centralWidget.setLayout(layout)
self.setCentralWidget(centralWidget)
+ @docstring(PlotWidget)
+ def setBackend(self, backend):
+ # Use PlotWidget here since we override PlotWindow behavior
+ PlotWidget.setBackend(self, backend)
+ self.__setCentralWidget()
+
def _dirtyCache(self):
self._cache = None
diff --git a/silx/gui/plot/MaskToolsWidget.py b/silx/gui/plot/MaskToolsWidget.py
index a95e277..8ff8641 100644
--- a/silx/gui/plot/MaskToolsWidget.py
+++ b/silx/gui/plot/MaskToolsWidget.py
@@ -116,7 +116,8 @@ class ImageMask(BaseMask):
"""
if kind == 'edf':
edfFile = EdfFile(filename, access="w+")
- edfFile.WriteImage({}, self.getMask(copy=False), Append=0)
+ header = {"program_name": "silx-mask", "masked_value": "nonzero"}
+ edfFile.WriteImage(header, self.getMask(copy=False), Append=0)
elif kind == 'tif':
tiffFile = TiffIO(filename, mode='w')
@@ -568,7 +569,9 @@ class MaskToolsWidget(BaseMaskToolsWidget):
filename = dialog.selectedFiles()[0]
dialog.close()
+ # Update the directory according to the user selection
self.maskFileDir = os.path.dirname(filename)
+
try:
self.load(filename)
except RuntimeWarning as e:
@@ -660,22 +663,35 @@ class MaskToolsWidget(BaseMaskToolsWidget):
if os.path.exists(filename) and "HDF5" not in nameFilter:
try:
os.remove(filename)
- except IOError:
+ except IOError as e:
msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Removing existing file")
msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
msg.setText("Cannot save.\n"
- "Input Output Error: %s" % (sys.exc_info()[1]))
+ "Input Output Error: %s" % strerror)
msg.exec_()
return
+ # Update the directory according to the user selection
self.maskFileDir = os.path.dirname(filename)
+
try:
self.save(filename, extension[1:])
except Exception as e:
- raise
msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Saving mask file")
msg.setIcon(qt.QMessageBox.Critical)
- msg.setText("Cannot save file %s\n%s" % (filename, e.args[0]))
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save file %s\n%s" % (filename, strerror))
msg.exec_()
def resetSelectionMask(self):
@@ -727,7 +743,7 @@ class MaskToolsWidget(BaseMaskToolsWidget):
# Convert from plot to array coords
center = (event['points'][0] - self._origin) / self._scale
size = event['points'][1] / self._scale
- center = center.astype(numpy.int) # (row, col)
+ center = center.astype(numpy.int64) # (row, col)
self._mask.updateEllipse(level, center[1], center[0], size[1], size[0], doMask)
self._mask.commit()
@@ -736,7 +752,7 @@ class MaskToolsWidget(BaseMaskToolsWidget):
doMask = self._isMasking()
# Convert from plot to array coords
vertices = (event['points'] - self._origin) / self._scale
- vertices = vertices.astype(numpy.int)[:, (1, 0)] # (row, col)
+ vertices = vertices.astype(numpy.int64)[:, (1, 0)] # (row, col)
self._mask.updatePolygon(level, vertices, doMask)
self._mask.commit()
diff --git a/silx/gui/plot/PlotInteraction.py b/silx/gui/plot/PlotInteraction.py
index d182a49..cfe140b 100644
--- a/silx/gui/plot/PlotInteraction.py
+++ b/silx/gui/plot/PlotInteraction.py
@@ -1604,6 +1604,8 @@ class DrawSelectMode(FocusManager):
def __init__(self, plot, shape, label, color, width):
eventHandlerClass = _DRAW_MODES[shape]
+ self._pan = Pan(plot)
+ self._panStart = None
parameters = {
'shape': shape,
'label': label,
@@ -1614,6 +1616,23 @@ class DrawSelectMode(FocusManager):
ItemsInteractionForCombo(plot),
eventHandlerClass(plot, parameters)))
+ def handleEvent(self, eventName, *args, **kwargs):
+ # Hack to add pan interaction to select-draw
+ # See issue Refactor PlotWidget interaction #3292
+ if eventName == 'press' and args[2] == MIDDLE_BTN:
+ self._panStart = args[:2]
+ self._pan.beginDrag(*args)
+ return # Consume middle click events
+ elif eventName == 'release' and args[2] == MIDDLE_BTN:
+ self._panStart = None
+ self._pan.endDrag(self._panStart, args[:2], MIDDLE_BTN)
+ return # Consume middle click events
+ elif self._panStart is not None and eventName == 'move':
+ x, y = args[:2]
+ self._pan.drag(x, y, MIDDLE_BTN)
+
+ super().handleEvent(eventName, *args, **kwargs)
+
def getDescription(self):
"""Returns the dict describing this interactive mode"""
params = self.eventHandlers[1].parameters.copy()
diff --git a/silx/gui/plot/PlotWidget.py b/silx/gui/plot/PlotWidget.py
index 9f9f846..23b7fe9 100755
--- a/silx/gui/plot/PlotWidget.py
+++ b/silx/gui/plot/PlotWidget.py
@@ -52,7 +52,7 @@ from silx.utils.property import classproperty
from silx.utils.deprecation import deprecated, deprecated_warning
try:
# Import matplotlib now to init matplotlib our way
- from . import matplotlib
+ import silx.gui.utils.matplotlib # noqa
except ImportError:
_logger.debug("matplotlib not available")
@@ -205,6 +205,12 @@ class PlotWidget(qt.QMainWindow):
It provides the visible state.
"""
+ _sigDefaultContextMenu = qt.Signal(qt.QMenu)
+ """Signal emitted when the default context menu of the plot is feed.
+
+ It provides the menu which will be displayed.
+ """
+
def __init__(self, parent=None, backend=None):
self._autoreplot = False
self._dirty = False
@@ -222,8 +228,6 @@ class PlotWidget(qt.QMainWindow):
self.setWindowTitle('PlotWidget')
# Init the backend
- if backend is None:
- backend = silx.config.DEFAULT_PLOT_BACKEND
self._backend = self.__getBackendClass(backend)(self, self)
self.setCallback() # set _callback
@@ -259,6 +263,12 @@ class PlotWidget(qt.QMainWindow):
self._grid = None
self._graphTitle = ''
+ self.__graphCursorShape = 'default'
+
+ # Set axes margins
+ self.__axesDisplayed = True
+ self.__axesMargins = 0., 0., 0., 0.
+ self.setAxesMargins(.15, .1, .1, .15)
self.setGraphTitle()
self.setGraphXLabel()
@@ -314,6 +324,9 @@ class PlotWidget(qt.QMainWindow):
:raise ValueError: In case the backend is not supported
:raise RuntimeError: If a backend is not available
"""
+ if backend is None:
+ backend = silx.config.DEFAULT_PLOT_BACKEND
+
if callable(backend):
return backend
@@ -375,6 +388,98 @@ class PlotWidget(qt.QMainWindow):
"""
silx.config.DEFAULT_PLOT_BACKEND = backend
+ def setBackend(self, backend):
+ """Set the backend to use for rendering.
+
+ Supported backends:
+
+ - 'matplotlib' and 'mpl': Matplotlib with Qt.
+ - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1)
+ - 'none': No backend, to run headless for testing purpose.
+
+ :param Union[str,BackendBase,List[Union[str,BackendBase]]] backend:
+ The backend to use, in:
+ 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none',
+ a :class:`BackendBase.BackendBase` class.
+ If multiple backends are provided, the first available one is used.
+ :raises ValueError: Unsupported backend descriptor
+ :raises RuntimeError: Error while loading a backend
+ """
+ backend = self.__getBackendClass(backend)(self, self)
+
+ # First save state that is stored in the backend
+ xaxis = self.getXAxis()
+ xmin, xmax = xaxis.getLimits()
+ ymin, ymax = self.getYAxis(axis='left').getLimits()
+ y2min, y2max = self.getYAxis(axis='right').getLimits()
+ isKeepDataAspectRatio = self.isKeepDataAspectRatio()
+ xTimeZone = xaxis.getTimeZone()
+ isXAxisTimeSeries = xaxis.getTickMode() == TickMode.TIME_SERIES
+
+ isYAxisInverted = self.getYAxis().isInverted()
+
+ # Remove all items from previous backend
+ for item in self.getItems():
+ item._removeBackendRenderer(self._backend)
+
+ # Switch backend
+ self._backend = backend
+ widget = self._backend.getWidgetHandle()
+ self.setCentralWidget(widget)
+ if widget is None:
+ _logger.info("PlotWidget backend does not support widget")
+
+ # Mark as newly dirty
+ self._dirty = False
+ self._setDirtyPlot()
+
+ # Synchronize/restore state
+ self._foregroundColorsUpdated()
+ self._backgroundColorsUpdated()
+
+ self._backend.setGraphCursorShape(self.getGraphCursorShape())
+ crosshairConfig = self.getGraphCursor()
+ if crosshairConfig is None:
+ self._backend.setGraphCursor(False, 'black', 1, '-')
+ else:
+ self._backend.setGraphCursor(True, *crosshairConfig)
+
+ self._backend.setGraphTitle(self.getGraphTitle())
+ self._backend.setGraphGrid(self.getGraphGrid())
+ if self.isAxesDisplayed():
+ self._backend.setAxesMargins(*self.getAxesMargins())
+ else:
+ self._backend.setAxesMargins(0., 0., 0., 0.)
+
+ # Set axes
+ xaxis = self.getXAxis()
+ self._backend.setGraphXLabel(xaxis.getLabel())
+ self._backend.setXAxisTimeZone(xTimeZone)
+ self._backend.setXAxisTimeSeries(isXAxisTimeSeries)
+ self._backend.setXAxisLogarithmic(
+ xaxis.getScale() == items.Axis.LOGARITHMIC)
+
+ for axis in ('left', 'right'):
+ self._backend.setGraphYLabel(self.getYAxis(axis).getLabel(), axis)
+ self._backend.setYAxisInverted(isYAxisInverted)
+ self._backend.setYAxisLogarithmic(
+ self.getYAxis().getScale() == items.Axis.LOGARITHMIC)
+
+ # Finally restore aspect ratio and limits
+ self._backend.setKeepDataAspectRatio(isKeepDataAspectRatio)
+ self.setLimits(xmin, xmax, ymin, ymax, y2min, y2max)
+
+ # Mark all items for update with new backend
+ for item in self.getItems():
+ item._updated()
+
+ def getBackend(self):
+ """Returns the backend currently used by :class:`PlotWidget`.
+
+ :rtype: ~silx.gui.plot.backend.BackendBase.BackendBase
+ """
+ return self._backend
+
def _getDirtyPlot(self):
"""Return the plot dirty flag.
@@ -403,6 +508,8 @@ class PlotWidget(qt.QMainWindow):
action = ClosePolygonInteractionAction(plot=self, parent=menu)
menu.addAction(action)
+ self._sigDefaultContextMenu.emit(menu)
+
# Make sure the plot is updated, especially when the plot is in
# draw interaction mode
menu.aboutToHide.connect(self.__simulateMouseMove)
@@ -538,6 +645,16 @@ class PlotWidget(qt.QMainWindow):
self._dataBackgroundColor = color
self._backgroundColorsUpdated()
+ dataBackgroundColor = qt.Property(
+ qt.QColor, getDataBackgroundColor, setDataBackgroundColor
+ )
+
+ backgroundColor = qt.Property(qt.QColor, getBackgroundColor, setBackgroundColor)
+
+ foregroundColor = qt.Property(qt.QColor, getForegroundColor, setForegroundColor)
+
+ gridColor = qt.Property(qt.QColor, getGridColor, setGridColor)
+
def showEvent(self, event):
if self._autoreplot and self._dirty:
self._backend.postRedisplay()
@@ -2405,18 +2522,61 @@ class PlotWidget(qt.QMainWindow):
assert(axis in ["left", "right"])
return self._yAxis if axis == "left" else self._yRightAxis
- def setAxesDisplayed(self, displayed):
+ def setAxesDisplayed(self, displayed: bool):
"""Display or not the axes.
:param bool displayed: If `True` axes are displayed. If `False` axes
are not anymore visible and the margin used for them is removed.
"""
- self._backend.setAxesDisplayed(displayed)
- self._setDirtyPlot()
- self._sigAxesVisibilityChanged.emit(displayed)
+ if displayed != self.__axesDisplayed:
+ self.__axesDisplayed = displayed
+ if displayed:
+ self._backend.setAxesMargins(*self.__axesMargins)
+ else:
+ self._backend.setAxesMargins(0., 0., 0., 0.)
+ self._setDirtyPlot()
+ self._sigAxesVisibilityChanged.emit(displayed)
+
+ def isAxesDisplayed(self) -> bool:
+ """Returns whether or not axes are currently displayed
+
+ :rtype: bool
+ """
+ return self.__axesDisplayed
+
+ def setAxesMargins(
+ self, left: float, top: float, right: float, bottom: float):
+ """Set ratios of margins surrounding data plot area.
+
+ All ratios must be within [0., 1.].
+ Sums of ratios of opposed side must be < 1.
+
+ :param float left: Left-side margin ratio.
+ :param float top: Top margin ratio
+ :param float right: Right-side margin ratio
+ :param float bottom: Bottom margin ratio
+ :raises ValueError:
+ """
+ for value in (left, top, right, bottom):
+ if value < 0. or value > 1.:
+ raise ValueError("Margin ratios must be within [0., 1.]")
+ if left + right >= 1. or top + bottom >= 1.:
+ raise ValueError("Sum of ratios of opposed sides >= 1")
+ margins = left, top, right, bottom
+
+ if margins != self.__axesMargins:
+ self.__axesMargins = margins
+ if self.isAxesDisplayed(): # Only apply if axes are displayed
+ self._backend.setAxesMargins(*margins)
+ self._setDirtyPlot()
- def _isAxesDisplayed(self):
- return self._backend.isAxesDisplayed()
+ def getAxesMargins(self):
+ """Returns ratio of margins surrounding data plot area.
+
+ :return: (left, top, right, bottom)
+ :rtype: List[float]
+ """
+ return self.__axesMargins
def setYAxisInverted(self, flag=True):
"""Set the Y axis orientation.
@@ -2980,11 +3140,19 @@ class PlotWidget(qt.QMainWindow):
# Interaction support
+ def getGraphCursorShape(self):
+ """Returns the current cursor shape.
+
+ :rtype: str
+ """
+ return self.__graphCursorShape
+
def setGraphCursorShape(self, cursor=None):
"""Set the cursor shape.
:param str cursor: Name of the cursor shape
"""
+ self.__graphCursorShape = cursor
self._backend.setGraphCursorShape(cursor)
@deprecated(replacement='getItems', since_version='0.13')
diff --git a/silx/gui/plot/PlotWindow.py b/silx/gui/plot/PlotWindow.py
index a3b70c6..3cd605f 100644
--- a/silx/gui/plot/PlotWindow.py
+++ b/silx/gui/plot/PlotWindow.py
@@ -224,6 +224,56 @@ class PlotWindow(PlotWidget):
self._sigAxesVisibilityChanged.connect(self._updateColorBarBackground)
self._updateColorBarBackground()
+ if control: # Create control button only if requested
+ self.controlButton = qt.QToolButton()
+ self.controlButton.setText("Options")
+ self.controlButton.setToolButtonStyle(qt.Qt.ToolButtonTextBesideIcon)
+ self.controlButton.setAutoRaise(True)
+ self.controlButton.setPopupMode(qt.QToolButton.InstantPopup)
+ menu = qt.QMenu(self)
+ menu.aboutToShow.connect(self._customControlButtonMenu)
+ self.controlButton.setMenu(menu)
+
+ self._positionWidget = None
+ if position: # Add PositionInfo widget to the bottom of the plot
+ if isinstance(position, abc.Iterable):
+ # Use position as a set of converters
+ converters = position
+ else:
+ converters = None
+ self._positionWidget = tools.PositionInfo(
+ plot=self, converters=converters)
+ # Set a snapping mode that is consistent with legacy one
+ self._positionWidget.setSnappingMode(
+ tools.PositionInfo.SNAPPING_CROSSHAIR |
+ tools.PositionInfo.SNAPPING_ACTIVE_ONLY |
+ tools.PositionInfo.SNAPPING_SYMBOLS_ONLY |
+ tools.PositionInfo.SNAPPING_CURVE |
+ tools.PositionInfo.SNAPPING_SCATTER)
+
+ self.__setCentralWidget()
+
+ # Creating the toolbar also create actions for toolbuttons
+ self._interactiveModeToolBar = tools.InteractiveModeToolBar(
+ parent=self, plot=self)
+ self.addToolBar(self._interactiveModeToolBar)
+
+ self._toolbar = self._createToolBar(title='Plot', parent=self)
+ self.addToolBar(self._toolbar)
+
+ self._outputToolBar = tools.OutputToolBar(parent=self, plot=self)
+ self._outputToolBar.getCopyAction().setVisible(copy)
+ self._outputToolBar.getSaveAction().setVisible(save)
+ self._outputToolBar.getPrintAction().setVisible(print_)
+ self.addToolBar(self._outputToolBar)
+
+ # Activate shortcuts in PlotWindow widget:
+ for toolbar in (self._interactiveModeToolBar, self._outputToolBar):
+ for action in toolbar.actions():
+ self.addAction(action)
+
+ def __setCentralWidget(self):
+ """Set central widget to host plot backend, colorbar, and bottom bar"""
gridLayout = qt.QGridLayout()
gridLayout.setSpacing(0)
gridLayout.setContentsMargins(0, 0, 0, 0)
@@ -233,42 +283,15 @@ class PlotWindow(PlotWidget):
gridLayout.setColumnStretch(0, 1)
centralWidget = qt.QWidget(self)
centralWidget.setLayout(gridLayout)
- self.setCentralWidget(centralWidget)
- self._positionWidget = None
-
- if control or position:
+ if hasattr(self, "controlButton") or self._positionWidget is not None:
hbox = qt.QHBoxLayout()
hbox.setContentsMargins(0, 0, 0, 0)
- if control:
- self.controlButton = qt.QToolButton()
- self.controlButton.setText("Options")
- self.controlButton.setToolButtonStyle(qt.Qt.ToolButtonTextBesideIcon)
- self.controlButton.setAutoRaise(True)
- self.controlButton.setPopupMode(qt.QToolButton.InstantPopup)
- menu = qt.QMenu(self)
- menu.aboutToShow.connect(self._customControlButtonMenu)
- self.controlButton.setMenu(menu)
-
+ if hasattr(self, "controlButton"):
hbox.addWidget(self.controlButton)
- if position: # Add PositionInfo widget to the bottom of the plot
- if isinstance(position, abc.Iterable):
- # Use position as a set of converters
- converters = position
- else:
- converters = None
- self._positionWidget = tools.PositionInfo(
- plot=self, converters=converters)
- # Set a snapping mode that is consistent with legacy one
- self._positionWidget.setSnappingMode(
- tools.PositionInfo.SNAPPING_CROSSHAIR |
- tools.PositionInfo.SNAPPING_ACTIVE_ONLY |
- tools.PositionInfo.SNAPPING_SYMBOLS_ONLY |
- tools.PositionInfo.SNAPPING_CURVE |
- tools.PositionInfo.SNAPPING_SCATTER)
-
+ if self._positionWidget is not None:
hbox.addWidget(self._positionWidget)
hbox.addStretch(1)
@@ -277,24 +300,12 @@ class PlotWindow(PlotWidget):
gridLayout.addWidget(bottomBar, 1, 0, 1, -1)
- # Creating the toolbar also create actions for toolbuttons
- self._interactiveModeToolBar = tools.InteractiveModeToolBar(
- parent=self, plot=self)
- self.addToolBar(self._interactiveModeToolBar)
-
- self._toolbar = self._createToolBar(title='Plot', parent=self)
- self.addToolBar(self._toolbar)
-
- self._outputToolBar = tools.OutputToolBar(parent=self, plot=self)
- self._outputToolBar.getCopyAction().setVisible(copy)
- self._outputToolBar.getSaveAction().setVisible(save)
- self._outputToolBar.getPrintAction().setVisible(print_)
- self.addToolBar(self._outputToolBar)
+ self.setCentralWidget(centralWidget)
- # Activate shortcuts in PlotWindow widget:
- for toolbar in (self._interactiveModeToolBar, self._outputToolBar):
- for action in toolbar.actions():
- self.addAction(action)
+ @docstring(PlotWidget)
+ def setBackend(self, backend):
+ super(PlotWindow, self).setBackend(backend)
+ self.__setCentralWidget() # Recreate PlotWindow's central widget
@docstring(PlotWidget)
def setBackgroundColor(self, color):
@@ -313,7 +324,7 @@ class PlotWindow(PlotWidget):
def _updateColorBarBackground(self):
"""Update the colorbar background according to the state of the plot"""
- if self._isAxesDisplayed():
+ if self.isAxesDisplayed():
color = self.getBackgroundColor()
else:
color = self.getDataBackgroundColor()
diff --git a/silx/gui/plot/ROIStatsWidget.py b/silx/gui/plot/ROIStatsWidget.py
new file mode 100644
index 0000000..094d66a
--- /dev/null
+++ b/silx/gui/plot/ROIStatsWidget.py
@@ -0,0 +1,780 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides widget for displaying statistics relative to a
+Region of interest and an item
+"""
+
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "22/07/2019"
+
+
+from contextlib import contextmanager
+from silx.gui import qt
+from silx.gui import icons
+from silx.gui.plot.StatsWidget import _StatsWidgetBase, StatsTable, _Container
+from silx.gui.plot.StatsWidget import UpdateModeWidget, UpdateMode
+from silx.gui.widgets.TableWidget import TableWidget
+from silx.gui.plot.items.roi import RegionOfInterest
+from silx.gui.plot import items as plotitems
+from silx.gui.plot.items.core import ItemChangedType
+from silx.gui.plot3d import items as plot3ditems
+from silx.gui.plot.CurvesROIWidget import ROI
+from silx.gui.plot import stats as statsmdl
+from collections import OrderedDict
+from silx.utils.proxy import docstring
+import silx.gui.plot.items.marker
+import silx.gui.plot.items.shape
+import functools
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+class _GetROIItemCoupleDialog(qt.QDialog):
+ """
+ Dialog used to know which plot item and which roi he wants
+ """
+ _COMPATIBLE_KINDS = ('curve', 'image', 'scatter', 'histogram')
+
+ def __init__(self, parent=None, plot=None, rois=None):
+ qt.QDialog.__init__(self, parent=parent)
+ assert plot is not None
+ assert rois is not None
+ self._plot = plot
+ self._rois = rois
+
+ self.setLayout(qt.QVBoxLayout())
+
+ # define the selection widget
+ self._selection_widget = qt.QWidget()
+ self._selection_widget.setLayout(qt.QHBoxLayout())
+ self._kindCB = qt.QComboBox(parent=self)
+ self._selection_widget.layout().addWidget(self._kindCB)
+ self._itemCB = qt.QComboBox(parent=self)
+ self._selection_widget.layout().addWidget(self._itemCB)
+ self._roiCB = qt.QComboBox(parent=self)
+ self._selection_widget.layout().addWidget(self._roiCB)
+ self.layout().addWidget(self._selection_widget)
+
+ # define modal buttons
+ types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel
+ self._buttonsModal = qt.QDialogButtonBox(parent=self)
+ self._buttonsModal.setStandardButtons(types)
+ self.layout().addWidget(self._buttonsModal)
+ self._buttonsModal.accepted.connect(self.accept)
+ self._buttonsModal.rejected.connect(self.reject)
+
+ # connect signal / slot
+ self._kindCB.currentIndexChanged.connect(self._updateValidItemAndRoi)
+
+ def _getCompatibleRois(self, kind):
+ """Return compatible rois for the given item kind"""
+ def is_compatible(roi, kind):
+ if isinstance(roi, RegionOfInterest):
+ return kind in ('image', 'scatter')
+ elif isinstance(roi, ROI):
+ return kind in ('curve', 'histogram')
+ else:
+ raise ValueError('kind not managed')
+ return list(filter(lambda x: is_compatible(x, kind), self._rois))
+
+ def exec_(self):
+ self._kindCB.clear()
+ self._itemCB.clear()
+ # filter kind without any items
+ self._valid_kinds = {}
+ # key is item type, value kinds
+ self._valid_rois = {}
+ # key is item type, value rois
+ self._kind_name_to_roi = {}
+ # key is (kind, roi name) value is roi
+ self._kind_name_to_item = {}
+ # key is (kind, legend name) value is item
+ for kind in _GetROIItemCoupleDialog._COMPATIBLE_KINDS:
+ def getItems(kind):
+ output = []
+ for item in self._plot.getItems():
+ type_ = self._plot._itemKind(item)
+ if type_ in kind and item.isVisible():
+ output.append(item)
+ return output
+
+ items = getItems(kind=kind)
+ rois = self._getCompatibleRois(kind=kind)
+ if len(items) > 0 and len(rois) > 0:
+ self._valid_kinds[kind] = items
+ self._valid_rois[kind] = rois
+ for roi in rois:
+ name = roi.getName()
+ self._kind_name_to_roi[(kind, name)] = roi
+ for item in items:
+ self._kind_name_to_item[(kind, item.getLegend())] = item
+
+ # filter roi according to kinds
+ if len(self._valid_kinds) == 0:
+ _logger.warning('no couple item/roi detected for displaying stats')
+ return self.reject()
+
+ for kind in self._valid_kinds:
+ self._kindCB.addItem(kind)
+ self._updateValidItemAndRoi()
+
+ return qt.QDialog.exec_(self)
+
+ def _updateValidItemAndRoi(self, *args, **kwargs):
+ self._itemCB.clear()
+ self._roiCB.clear()
+ kind = self._kindCB.currentText()
+ for roi in self._valid_rois[kind]:
+ self._roiCB.addItem(roi.getName())
+ for item in self._valid_kinds[kind]:
+ self._itemCB.addItem(item.getLegend())
+
+ def getROI(self):
+ kind = self._kindCB.currentText()
+ roi_name = self._roiCB.currentText()
+ return self._kind_name_to_roi[(kind, roi_name)]
+
+ def getItem(self):
+ kind = self._kindCB.currentText()
+ item_name = self._itemCB.currentText()
+ return self._kind_name_to_item[(kind, item_name)]
+
+
+class ROIStatsItemHelper(object):
+ """Item utils to associate a plot item and a roi
+
+ Display on one row statistics regarding the couple
+ (Item (plot item) / roi).
+
+ :param Item plot_item: item for which we want statistics
+ :param Union[ROI,RegionOfInterest]: region of interest to use for
+ statistics.
+ """
+ def __init__(self, plot_item, roi):
+ self._plot_item = plot_item
+ self._roi = roi
+
+ @property
+ def roi(self):
+ """roi"""
+ return self._roi
+
+ def roi_name(self):
+ if isinstance(self._roi, ROI):
+ return self._roi.getName()
+ elif isinstance(self._roi, RegionOfInterest):
+ return self._roi.getName()
+ else:
+ raise TypeError('Unmanaged roi type')
+
+ @property
+ def roi_kind(self):
+ """roi class"""
+ return self._roi.__class__
+
+ # TODO: should call a util function from the wrapper ?
+ def item_kind(self):
+ """item kind"""
+ if isinstance(self._plot_item, plotitems.Curve):
+ return 'curve'
+ elif isinstance(self._plot_item, plotitems.ImageData):
+ return 'image'
+ elif isinstance(self._plot_item, plotitems.Scatter):
+ return 'scatter'
+ elif isinstance(self._plot_item, plotitems.Histogram):
+ return 'histogram'
+ elif isinstance(self._plot_item, (plot3ditems.ImageData,
+ plot3ditems.ScalarField3D)):
+ return 'image'
+ elif isinstance(self._plot_item, (plot3ditems.Scatter2D,
+ plot3ditems.Scatter3D)):
+ return 'scatter'
+
+ @property
+ def item_legend(self):
+ """legend of the plot Item"""
+ return self._plot_item.getLegend()
+
+ def id_key(self):
+ """unique key to represent the couple (item, roi)"""
+ return (self.item_kind(), self.item_legend, self.roi_kind,
+ self.roi_name())
+
+
+class _StatsROITable(_StatsWidgetBase, TableWidget):
+ """
+ Table sued to display some statistics regarding a couple (item/roi)
+ """
+ _LEGEND_HEADER_DATA = 'legend'
+
+ _KIND_HEADER_DATA = 'kind'
+
+ _ROI_HEADER_DATA = 'roi'
+
+ sigUpdateModeChanged = qt.Signal(object)
+ """Signal emitted when the update mode changed"""
+
+ def __init__(self, parent, plot):
+ TableWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(self, statsOnVisibleData=False,
+ displayOnlyActItem=False)
+ self.__region_edition_callback = {}
+ """We need to keep trace of the roi signals connection because
+ the roi emits the sigChanged during roi edition"""
+ self._items = {}
+ self.setRowCount(0)
+ self.setColumnCount(3)
+
+ # Init headers
+ headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA)
+ self.setHorizontalHeaderItem(0, headerItem)
+ headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA)
+ self.setHorizontalHeaderItem(1, headerItem)
+ headerItem = qt.QTableWidgetItem(self._ROI_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._ROI_HEADER_DATA)
+ self.setHorizontalHeaderItem(2, headerItem)
+
+ self.setSortingEnabled(True)
+ self.setPlot(plot)
+
+ self.__plotItemToItems = {}
+ """Key is plotItem, values is list of __RoiStatsItemWidget"""
+ self.__roiToItems = {}
+ """Key is roi, values is list of __RoiStatsItemWidget"""
+ self.__roisKeyToRoi = {}
+
+ def add(self, item):
+ assert isinstance(item, ROIStatsItemHelper)
+ if item.id_key() in self._items:
+ _logger.warning(item.id_key(), 'is already present')
+ return None
+ self._items[item.id_key()] = item
+ self._addItem(item)
+ return item
+
+ def _addItem(self, item):
+ """
+ Add a _RoiStatsItemWidget item to the table.
+
+ :param item:
+ :return: True if successfully added.
+ """
+ if not isinstance(item, ROIStatsItemHelper):
+ # skipped because also receive all new plot item (Marker...) that
+ # we don't want to manage in this case.
+ return
+ # plotItem = item.getItem()
+ # roi = item.getROI()
+ kind = item.item_kind()
+ if kind not in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.info("Item has not a supported type: %s", item)
+ return False
+
+ # register the roi and the kind
+ self._registerPlotItem(item)
+ self._registerROI(item)
+
+ # Prepare table items
+ tableItems = [
+ qt.QTableWidgetItem(), # Legend
+ qt.QTableWidgetItem(), # Kind
+ qt.QTableWidgetItem()] # roi
+
+ for column in range(3, self.columnCount()):
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+
+ formatter = self._statsHandler.formatters[name]
+ if formatter:
+ tableItem = formatter.tabWidgetItemClass()
+ else:
+ tableItem = qt.QTableWidgetItem()
+
+ tooltip = self._statsHandler.stats[name].getToolTip(kind=kind)
+ if tooltip is not None:
+ tableItem.setToolTip(tooltip)
+
+ tableItems.append(tableItem)
+
+ # Disable sorting while adding table items
+ with self._disableSorting():
+ # Add a row to the table
+ self.setRowCount(self.rowCount() + 1)
+
+ # Add table items to the last row
+ row = self.rowCount() - 1
+ for column, tableItem in enumerate(tableItems):
+ tableItem.setData(qt.Qt.UserRole, _Container(item))
+ tableItem.setFlags(
+ qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, column, tableItem)
+
+ # Update table items content
+ self._updateStats(item, data_changed=True)
+
+ # Listen for item changes
+ # Using queued connection to avoid issue with sender
+ # being that of the signal calling the signal
+ item._plot_item.sigItemChanged.connect(self._plotItemChanged,
+ qt.Qt.QueuedConnection)
+ return True
+
+ def _removeAllItems(self):
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ # item = self._tableItemToItem(tableItem)
+ # item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.clearContents()
+ self.setRowCount(0)
+
+ def clear(self):
+ self._removeAllItems()
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ self._removeAllItems()
+ _StatsWidgetBase.setStats(self, statsHandler)
+
+ self.setRowCount(0)
+ self.setColumnCount(len(self._statsHandler.stats) + 3) # + legend, kind and roi # noqa
+
+ for index, stat in enumerate(self._statsHandler.stats.values()):
+ headerItem = qt.QTableWidgetItem(stat.name.capitalize())
+ headerItem.setData(qt.Qt.UserRole, stat.name)
+ if stat.description is not None:
+ headerItem.setToolTip(stat.description)
+ self.setHorizontalHeaderItem(3 + index, headerItem)
+
+ horizontalHeader = self.horizontalHeader()
+ if hasattr(horizontalHeader, 'setSectionResizeMode'): # Qt5
+ horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+ else: # Qt4
+ horizontalHeader.setResizeMode(qt.QHeaderView.ResizeToContents)
+
+ self._updateItemObserve()
+
+ def _updateItemObserve(self, *args):
+ pass
+
+ def _dataChanged(self, item):
+ pass
+
+ def _updateStats(self, item, data_changed=False, roi_changed=False):
+ assert isinstance(item, ROIStatsItemHelper)
+ plotItem = item._plot_item
+ roi = item._roi
+ if item is None:
+ return
+ plot = self.getPlot()
+ if plot is None:
+ _logger.info("Plot not available")
+ return
+
+ row = self._itemToRow(item)
+ if row is None:
+ _logger.error("This item is not in the table: %s", str(item))
+ return
+
+ statsHandler = self.getStatsHandler()
+ if statsHandler is not None:
+ stats = statsHandler.calculate(plotItem, plot,
+ onlimits=self._statsOnVisibleData,
+ roi=roi, data_changed=data_changed,
+ roi_changed=roi_changed)
+ else:
+ stats = {}
+
+ with self._disableSorting():
+ for name, tableItem in self._itemToTableItems(item).items():
+ if name == self._LEGEND_HEADER_DATA:
+ text = self._plotWrapper.getLabel(plotItem)
+ tableItem.setText(text)
+ elif name == self._KIND_HEADER_DATA:
+ tableItem.setText(self._plotWrapper.getKind(plotItem))
+ elif name == self._ROI_HEADER_DATA:
+ name = roi.getName()
+ tableItem.setText(name)
+ else:
+ value = stats.get(name)
+ if value is None:
+ _logger.error("Value not found for: %s", name)
+ tableItem.setText('-')
+ else:
+ tableItem.setText(str(value))
+
+ @contextmanager
+ def _disableSorting(self):
+ """Context manager that disables table sorting
+
+ Previous state is restored when leaving
+ """
+ sorting = self.isSortingEnabled()
+ if sorting:
+ self.setSortingEnabled(False)
+ yield
+ if sorting:
+ self.setSortingEnabled(sorting)
+
+ def _itemToRow(self, item):
+ """Find the row corresponding to a plot item
+
+ :param item: The plot item
+ :return: The corresponding row index
+ :rtype: Union[int,None]
+ """
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ if self._tableItemToItem(tableItem) == item:
+ return row
+ return None
+
+ def _tableItemToItem(self, tableItem):
+ """Find the plot item corresponding to a table item
+
+ :param QTableWidgetItem tableItem:
+ :rtype: QObject
+ """
+ container = tableItem.data(qt.Qt.UserRole)
+ return container()
+
+ def _itemToTableItems(self, item):
+ """Find all table items corresponding to a plot item
+
+ :param item: The plot item
+ :return: An ordered dict of column name to QTableWidgetItem mapping
+ for the given plot item.
+ :rtype: OrderedDict
+ """
+ result = OrderedDict()
+ row = self._itemToRow(item)
+ if row is not None:
+ for column in range(self.columnCount()):
+ tableItem = self.item(row, column)
+ if self._tableItemToItem(tableItem) != item:
+ _logger.error("Table item/plot item mismatch")
+ else:
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+ result[name] = tableItem
+ return result
+
+ def _plotItemToItems(self, plotItem):
+ """Return all _RoiStatsItemWidget associated to the plotItem
+ Needed for updating on itemChanged signal
+ """
+ if plotItem in self.__plotItemToItems:
+ return []
+ else:
+ return self.__plotItemToItems[plotItem]
+
+ def _registerPlotItem(self, item):
+ if item._plot_item not in self.__plotItemToItems:
+ self.__plotItemToItems[item._plot_item] = set()
+ self.__plotItemToItems[item._plot_item].add(item)
+
+ def _roiToItems(self, roi):
+ """Return all _RoiStatsItemWidget associated to the roi
+ Needed for updating on roiChanged signal
+ """
+ if roi in self.__roiToItems:
+ return []
+ else:
+ return self.__roiToItems[roi]
+
+ def _registerROI(self, item):
+ if item._roi not in self.__roiToItems:
+ self.__roiToItems[item._roi] = set()
+ # TODO: normalize also sig name
+ if isinstance(item._roi, RegionOfInterest):
+ # item connection within sigRegionChanged should only be
+ # stopped during the region edition
+ self.__region_edition_callback[item._roi] = functools.partial(
+ self._updateAllStats, False, True)
+ item._roi.sigRegionChanged.connect(self.__region_edition_callback[item._roi])
+ item._roi.sigEditingStarted.connect(functools.partial(
+ self._startFiltering, item._roi))
+ item._roi.sigEditingFinished.connect(functools.partial(
+ self._endFiltering, item._roi))
+ else:
+ item._roi.sigChanged.connect(functools.partial(
+ self._updateAllStats, False, True))
+ self.__roiToItems[item._roi].add(item)
+
+ def _startFiltering(self, roi):
+ roi.sigRegionChanged.disconnect(self.__region_edition_callback[roi])
+
+ def _endFiltering(self, roi):
+ roi.sigRegionChanged.connect(self.__region_edition_callback[roi])
+ self._updateAllStats(roi_changed=True)
+
+ def unregisterROI(self, roi):
+ if roi in self.__roiToItems:
+ del self.__roiToItems[roi]
+ if isinstance(roi, RegionOfInterest):
+ roi.sigRegionEditionStarted.disconnect(functools.partial(
+ self._startFiltering, roi))
+ roi.sigRegionEditionFinished.disconnect(functools.partial(
+ self._startFiltering, roi))
+ try:
+ roi.sigRegionChanged.disconnect(self._updateAllStats)
+ except:
+ pass
+ else:
+ roi.sigChanged.disconnect(self._updateAllStats)
+
+ def _plotItemChanged(self, event):
+ """Handle modifications of the items.
+
+ :param event:
+ """
+ if event is ItemChangedType.DATA:
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ return
+ if self._skipPlotItemChangedEvent(event) is True:
+ return
+ else:
+ sender = self.sender()
+ for item in self.__plotItemToItems[sender]:
+ # TODO: get all concerned items
+ self._updateStats(item, data_changed=True)
+ # deal with stat items visibility
+ if event is ItemChangedType.VISIBLE:
+ if len(self._itemToTableItems(item).items()) > 0:
+ item_0 = list(self._itemToTableItems(item).values())[0]
+ row_index = item_0.row()
+ self.setRowHidden(row_index, not item.isVisible())
+
+ def _removeItem(self, itemKey):
+ if isinstance(itemKey, (silx.gui.plot.items.marker.Marker,
+ silx.gui.plot.items.shape.Shape)):
+ return
+ if itemKey not in self._items:
+ _logger.warning('key not recognized. Won\'t remove any item')
+ return
+ item = self._items[itemKey]
+ row = self._itemToRow(item)
+ if row is None:
+ kind = self._plotWrapper.getKind(item)
+ if kind in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.error("Removing item that is not in table: %s", str(item))
+ return
+ item._plot_item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.removeRow(row)
+ del self._items[itemKey]
+
+ def _updateAllStats(self, is_request=False, roi_changed=False):
+ """Update stats for all rows in the table
+
+ :param bool is_request: True if come from a manual request
+ """
+ if (self.getUpdateMode() is UpdateMode.MANUAL and
+ not is_request and not roi_changed):
+ return
+
+ with self._disableSorting():
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ self._updateStats(item, roi_changed=roi_changed,
+ data_changed=is_request)
+
+ def _plotCurrentChanged(self, *args):
+ pass
+
+ def _getRoi(self, kind, name):
+ """return the roi fitting the requirement kind, name. This information
+ is enough to be sure it is unique (in the widget)"""
+ for roi in self.__roiToItems:
+ roiName = roi.getName()
+ if isinstance(roi, kind) and name == roiName:
+ return roi
+ return None
+
+ def _getPlotItem(self, kind, legend):
+ """return the plotItem fitting the requirement kind, legend.
+ This information is enough to be sure it is unique (in the widget)"""
+ for plotItem in self.__plotItemToItems:
+ if legend == plotItem.getLegend() and self._plotWrapper.getKind(plotItem) == kind:
+ return plotItem
+ return None
+
+
+class ROIStatsWidget(qt.QMainWindow):
+ """
+ Widget used to define stats item for a couple(roi, plotItem).
+ Stats will be computing on a given item (curve, image...) in the given
+ region of interest.
+
+ It also provide an interface for adding and removing items.
+
+ .. snapshotqt:: img/ROIStatsWidget.png
+ :width: 300px
+ :align: center
+
+ from silx.gui import qt
+ from silx.gui.plot import Plot2D
+ from silx.gui.plot.ROIStatsWidget import ROIStatsWidget
+ from silx.gui.plot.items.roi import RectangleROI
+ import numpy
+ plot = Plot2D()
+ plot.addImage(numpy.arange(10000).reshape(100, 100), legend='img')
+ plot.show()
+ rectangleROI = RectangleROI()
+ rectangleROI.setGeometry(origin=(0, 100), size=(20, 20))
+ rectangleROI.setName('Initial ROI')
+ widget = ROIStatsWidget(plot=plot)
+ widget.setStats([('sum', numpy.sum), ('mean', numpy.mean)])
+ widget.registerROI(rectangleROI)
+ widget.addItem(roi=rectangleROI, plotItem=plot.getImage('img'))
+ widget.show()
+
+ :param Union[qt.QWidget,None] parent: parent qWidget
+ :param PlotWindow plot: plot widget containing the items
+ :param stats: stats to display
+ :param tuple rois: tuple of rois to manage
+ """
+
+ def __init__(self, parent=None, plot=None, stats=None, rois=None):
+ qt.QMainWindow.__init__(self, parent)
+
+ toolbar = qt.QToolBar(self)
+ icon = icons.getQIcon('add')
+ self._rois = list(rois) if rois is not None else []
+ self._addAction = qt.QAction(icon, 'add item/roi', toolbar)
+ self._addAction.triggered.connect(self._addRoiStatsItem)
+ icon = icons.getQIcon('rm')
+ self._removeAction = qt.QAction(icon, 'remove item/roi', toolbar)
+ self._removeAction.triggered.connect(self._removeCurrentRow)
+
+ toolbar.addAction(self._addAction)
+ toolbar.addAction(self._removeAction)
+ self.addToolBar(toolbar)
+
+ self._plot = plot
+ self._statsROITable = _StatsROITable(parent=self, plot=self._plot)
+ self.setStats(stats=stats)
+ self.setCentralWidget(self._statsROITable)
+ self.setWindowFlags(qt.Qt.Widget)
+
+ # expose API
+ self._setUpdateMode = self._statsROITable.setUpdateMode
+ self._updateAllStats = self._statsROITable._updateAllStats
+
+ # setup
+ self._statsROITable.setSelectionBehavior(qt.QTableWidget.SelectRows)
+
+ def registerROI(self, roi):
+ """For now there is no direct link between roi and plot. That is why
+ we need to add/register them to be able to associate them"""
+ self._rois.append(roi)
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ self._plot = plot
+
+ def getPlot(self):
+ return self._plot
+
+ @docstring(_StatsROITable)
+ def setStats(self, stats):
+ if stats is not None:
+ self._statsROITable.setStats(statsHandler=stats)
+
+ @docstring(_StatsROITable)
+ def getStatsHandler(self):
+ """
+
+ :return:
+ """
+ return self._statsROITable.getStatsHandler()
+
+ def _addRoiStatsItem(self):
+ """Ask the user what couple ROI / item he want to display"""
+ dialog = _GetROIItemCoupleDialog(parent=self, plot=self._plot,
+ rois=self._rois)
+ if dialog.exec_():
+ self.addItem(roi=dialog.getROI(), plotItem=dialog.getItem())
+
+ def addItem(self, plotItem, roi):
+ """
+ Add a row of statitstic regarding the couple (plotItem, roi)
+
+ :param Item plotItem: item to use for statistics
+ :param roi: region of interest to limit the statistic.
+ :type: Union[ROI, RegionOfInterest]
+ :return: None of failed to add the item
+ :rtype: Union[None,ROIStatsItemHelper]
+ """
+ statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem)
+ return self._statsROITable.add(item=statsItem)
+
+ def removeItem(self, plotItem, roi):
+ """
+ Remove the row associated to the couple (plotItem, roi)
+
+ :param Item plotItem: item to use for statistics
+ :param roi: region of interest to limit the statistic.
+ :type: Union[ROI,RegionOfInterest]
+ """
+ statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem)
+ self._statsROITable._removeItem(itemKey=statsItem.id_key())
+
+ def _removeCurrentRow(self):
+ def is1DKind(kind):
+ if kind in ('curve', 'histogram', 'scatter'):
+ return True
+ else:
+ return False
+
+ currentRow = self._statsROITable.currentRow()
+ item_kind = self._statsROITable.item(currentRow, 1).text()
+ item_legend = self._statsROITable.item(currentRow, 0).text()
+
+ roi_name = self._statsROITable.item(currentRow, 2).text()
+ roi_kind = ROI if is1DKind(item_kind) else RegionOfInterest
+ roi = self._statsROITable._getRoi(kind=roi_kind, name=roi_name)
+ if roi is None:
+ _logger.warning('failed to retrieve the roi you want to remove')
+ return False
+ plot_item = self._statsROITable._getPlotItem(kind=item_kind,
+ legend=item_legend)
+ if plot_item is None:
+ _logger.warning('failed to retrieve the plot item you want to'
+ 'remove')
+ return False
+ return self.removeItem(plotItem=plot_item, roi=roi)
diff --git a/silx/gui/plot/ScatterMaskToolsWidget.py b/silx/gui/plot/ScatterMaskToolsWidget.py
index 8ff2483..5ae8653 100644
--- a/silx/gui/plot/ScatterMaskToolsWidget.py
+++ b/silx/gui/plot/ScatterMaskToolsWidget.py
@@ -102,7 +102,7 @@ class ScatterMask(BaseMask):
self._mask[indices] = level
else:
# unmask only where mask level is the specified value
- indices_stencil = numpy.zeros_like(self._mask, dtype=numpy.bool)
+ indices_stencil = numpy.zeros_like(self._mask, dtype=bool)
indices_stencil[indices] = True
self._mask[numpy.logical_and(self._mask == level, indices_stencil)] = 0
self._notify()
@@ -431,7 +431,9 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
filename = dialog.selectedFiles()[0]
dialog.close()
+ # Update the directory according to the user selection
self.maskFileDir = os.path.dirname(filename)
+
try:
self.load(filename)
# except RuntimeWarning as e:
@@ -475,21 +477,35 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
if os.path.exists(filename):
try:
os.remove(filename)
- except IOError:
+ except IOError as e:
msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Removing existing file")
msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
msg.setText("Cannot save.\n"
- "Input Output Error: %s" % (sys.exc_info()[1]))
+ "Input Output Error: %s" % strerror)
msg.exec_()
return
+ # Update the directory according to the user selection
self.maskFileDir = os.path.dirname(filename)
+
try:
self.save(filename, extension[1:])
except Exception as e:
msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Saving mask file")
msg.setIcon(qt.QMessageBox.Critical)
- msg.setText("Cannot save file %s\n%s" % (filename, e.args[0]))
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save file %s\n%s" % (filename, strerror))
msg.exec_()
def resetSelectionMask(self):
diff --git a/silx/gui/plot/StackView.py b/silx/gui/plot/StackView.py
index cb7ece1..40e0661 100644
--- a/silx/gui/plot/StackView.py
+++ b/silx/gui/plot/StackView.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -548,15 +548,8 @@ class StackView(qt.QMainWindow):
perspective_changed = True
self.setPerspective(perspective)
- # This call to setColormap redefines the meaning of autoscale
- # for 3D volume: take global min/max rather than frame min/max
if self.__autoscaleCmap:
- # note: there is no real autoscale in the stack widget, it is more
- # like a hack computing stack min and max
- colormap = self.getColormap()
- _vmin, _vmax = colormap.getColormapRange(data=self._stack)
- colormap.setVRange(_vmin, _vmax)
- self.setColormap(colormap=colormap)
+ self.scaleColormapRangeToStack()
# init plot
self._stackItem.setStackData(self.__transposed_view, 0, copy=False)
@@ -791,6 +784,22 @@ class StackView(qt.QMainWindow):
# specifying a special colormap
return self._plot.getDefaultColormap()
+ def scaleColormapRangeToStack(self):
+ """Scale colormap range according to current stack data.
+
+ If no stack has been set through :meth:`setStack`, this has no effect.
+
+ The range scaling mode is given by current :class:`Colormap`'s
+ :meth:`Colormap.getAutoscaleMode`.
+ """
+ stack = self.getStack(copy=False, returnNumpyArray=True)
+ if stack is None:
+ return # No-op
+
+ colormap = self.getColormap()
+ vmin, vmax = colormap.getColormapRange(data=stack[0])
+ colormap.setVRange(vmin=vmin, vmax=vmax)
+
def setColormap(self, colormap=None, normalization=None,
autoscale=None, vmin=None, vmax=None, colors=None):
"""Set the colormap and update active image.
@@ -860,31 +869,14 @@ class StackView(qt.QMainWindow):
vmax=vmax,
colors=colors)
- # Patch: since we don't apply this colormap to a single 2D data but
- # a 2D stack we have to deal manually with vmin, vmax
- if autoscale is None:
- # set default
- autoscale = False
- elif autoscale and is_dataset(self._stack):
- # h5py dataset has no min()/max() methods
- raise RuntimeError(
- "Cannot auto-scale colormap for a h5py dataset")
- else:
- autoscale = autoscale
- self.__autoscaleCmap = autoscale
-
- if autoscale and (self._stack is not None):
- _vmin, _vmax = _colormap.getColormapRange(data=self._stack)
- _colormap.setVRange(vmin=_vmin, vmax=_vmax)
- else:
- if vmin is None and self._stack is not None:
- _colormap.setVMin(self._stack.min())
- else:
- _colormap.setVMin(vmin)
- if vmax is None and self._stack is not None:
- _colormap.setVMax(self._stack.max())
- else:
- _colormap.setVMax(vmax)
+ if autoscale is not None:
+ deprecated_warning(
+ type_='function',
+ name='setColormap',
+ reason='autoscale argument is replaced by a method',
+ replacement='scaleColormapRangeToStack',
+ since_version='0.14')
+ self.__autoscaleCmap = bool(autoscale)
cursorColor = cursorColorForColormap(_colormap.getName())
self._plot.setInteractiveMode('zoom', color=cursorColor)
@@ -896,6 +888,12 @@ class StackView(qt.QMainWindow):
if isinstance(activeImage, items.ColormapMixIn):
activeImage.setColormap(self.getColormap())
+ if self.__autoscaleCmap:
+ # scaleColormapRangeToStack needs to be called **after**
+ # setDefaultColormap so getColormap returns the right colormap
+ self.scaleColormapRangeToStack()
+
+
@deprecated(replacement="getPlotWidget", since_version="0.13")
def getPlot(self):
return self.getPlotWidget()
diff --git a/silx/gui/plot/StatsWidget.py b/silx/gui/plot/StatsWidget.py
index 6b92ea0..26b48db 100644
--- a/silx/gui/plot/StatsWidget.py
+++ b/silx/gui/plot/StatsWidget.py
@@ -449,10 +449,12 @@ class _StatsWidgetBase(object):
_displayOnlyActItem option."""
raise NotImplementedError('Base class')
- def _updateStats(self, item):
+ def _updateStats(self, item, data_changed=False, roi_changed=False):
"""Update displayed information for given plot item
:param item: The plot item
+ :param bool data_changed: is the item data changed.
+ :param bool roi_changed: is the associated roi changed.
"""
raise NotImplementedError('Base class')
@@ -548,7 +550,7 @@ class _StatsWidgetBase(object):
class StatsTable(_StatsWidgetBase, TableWidget):
"""
- TableWidget displaying for each curves contained by the Plot some
+ TableWidget displaying for each items contained by the Plot some
information:
* legend
@@ -582,10 +584,10 @@ class StatsTable(_StatsWidgetBase, TableWidget):
self.setColumnCount(2)
# Init headers
- headerItem = qt.QTableWidgetItem('Legend')
+ headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title())
headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA)
self.setHorizontalHeaderItem(0, headerItem)
- headerItem = qt.QTableWidgetItem('Kind')
+ headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title())
headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA)
self.setHorizontalHeaderItem(1, headerItem)
@@ -750,7 +752,7 @@ class StatsTable(_StatsWidgetBase, TableWidget):
return
else:
item = self.sender()
- self._updateStats(item)
+ self._updateStats(item, data_changed=True)
# deal with stat items visibility
if event is ItemChangedType.VISIBLE:
if len(self._itemToTableItems(item).items()) > 0:
@@ -812,7 +814,7 @@ class StatsTable(_StatsWidgetBase, TableWidget):
self.setItem(row, column, tableItem)
# Update table items content
- self._updateStats(item)
+ self._updateStats(item, data_changed=True)
# Listen for item changes
# Using queued connection to avoid issue with sender
@@ -845,10 +847,12 @@ class StatsTable(_StatsWidgetBase, TableWidget):
self.clearContents()
self.setRowCount(0)
- def _updateStats(self, item):
+ def _updateStats(self, item, data_changed=False, roi_changed=False):
"""Update displayed information for given plot item
:param item: The plot item
+ :param bool data_changed: is the item data changed.
+ :param bool roi_changed: is the associated roi changed.
"""
if item is None:
return
@@ -865,7 +869,8 @@ class StatsTable(_StatsWidgetBase, TableWidget):
statsHandler = self.getStatsHandler()
if statsHandler is not None:
stats = statsHandler.calculate(
- item, plot, self._statsOnVisibleData)
+ item, plot, self._statsOnVisibleData,
+ data_changed=data_changed, roi_changed=roi_changed)
else:
stats = {}
@@ -895,7 +900,7 @@ class StatsTable(_StatsWidgetBase, TableWidget):
for row in range(self.rowCount()):
tableItem = self.item(row, 0)
item = self._tableItemToItem(tableItem)
- self._updateStats(item)
+ self._updateStats(item, data_changed=is_request)
def _currentItemChanged(self, current, previous):
"""Handle change of selection in table and sync plot selection
@@ -1392,7 +1397,7 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
"""
return self._item_kind
- def _setItem(self, item):
+ def _setItem(self, item, data_changed=True):
if item is None:
for stat_name, stat_widget in self._statQlineEdit.items():
stat_widget.setText('')
@@ -1402,7 +1407,8 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
if plot is not None:
statsValDict = self._statsHandler.calculate(item,
plot,
- self._statsOnVisibleData)
+ self._statsOnVisibleData,
+ data_changed=data_changed)
for statName, statVal in list(statsValDict.items()):
self._statQlineEdit[statName].setText(statVal)
@@ -1417,7 +1423,7 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
items = list(filter(kind_filter, _items))
assert len(items) in (0, 1)
_item = items[0] if len(items) == 1 else None
- self._setItem(_item)
+ self._setItem(_item, data_changed=True)
def _updateCurrentItem(self):
self._updateItemObserve()
@@ -1432,7 +1438,7 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
def _removeItem(self, item):
raise NotImplementedError('Display only the active item')
- def _plotCurrentChanged(selfself, current):
+ def _plotCurrentChanged(self, current):
raise NotImplementedError('Display only the active item')
def _updateModeHasChanged(self):
diff --git a/silx/gui/plot/_BaseMaskToolsWidget.py b/silx/gui/plot/_BaseMaskToolsWidget.py
index aa4921c..3298498 100644
--- a/silx/gui/plot/_BaseMaskToolsWidget.py
+++ b/silx/gui/plot/_BaseMaskToolsWidget.py
@@ -378,7 +378,7 @@ class BaseMaskToolsWidget(qt.QWidget):
"""
super(BaseMaskToolsWidget, self).__init__(parent)
# register if the user as force a color for the corresponding mask level
- self._defaultColors = numpy.ones((self._maxLevelNumber + 1), dtype=numpy.bool)
+ self._defaultColors = numpy.ones((self._maxLevelNumber + 1), dtype=bool)
# overlays colors set by the user
self._overlayColors = numpy.zeros((self._maxLevelNumber + 1, 3), dtype=numpy.float32)
@@ -459,6 +459,18 @@ class BaseMaskToolsWidget(qt.QWidget):
self._levelWidget.setVisible(self._multipleMasks != 'single')
self._clearAllBtn.setVisible(self._multipleMasks != 'single')
+ def setMaskFileDirectory(self, path):
+ """Set the default directory to use by load/save GUI tools
+
+ The directory is also updated by the user, if he change the location
+ of the dialog.
+ """
+ self.maskFileDir = path
+
+ def getMaskFileDirectory(self):
+ """Get the default directory used by load/save GUI tools"""
+ return self.maskFileDir
+
@property
def maskFileDir(self):
"""The directory from which to load/save mask from/to files."""
diff --git a/silx/gui/plot/_utils/dtime_ticklayout.py b/silx/gui/plot/_utils/dtime_ticklayout.py
index 23c9dce..ebf775b 100644
--- a/silx/gui/plot/_utils/dtime_ticklayout.py
+++ b/silx/gui/plot/_utils/dtime_ticklayout.py
@@ -166,7 +166,7 @@ def setDateElement(dateTime, value, unit):
def roundToElement(dateTime, unit):
- """ Returns a copy of dateTime with the
+ """ Returns a copy of dateTime rounded to given unit
:param datetime.datetime: date time object
:param DtUnit unit: unit
@@ -330,15 +330,19 @@ def niceDateTimeElement(value, unit, isRound=False):
def findStartDate(dMin, dMax, nTicks):
""" Rounds a date down to the nearest nice number of ticks
"""
- assert dMax > dMin, \
+ assert dMax >= dMin, \
"dMin ({}) should come before dMax ({})".format(dMin, dMax)
+ if dMin == dMax:
+ # Fallback when range is smaller than microsecond resolution
+ return dMin, 1, DtUnit.MICRO_SECONDS
+
delta = dMax - dMin
lengthSec = delta.total_seconds()
_logger.debug("findStartDate: {}, {} (duration = {} sec, {} days)"
.format(dMin, dMax, lengthSec, lengthSec / SECONDS_PER_DAY))
- length, unit = bestUnit(delta.total_seconds())
+ length, unit = bestUnit(lengthSec)
niceLength = niceDateTimeElement(length, unit)
_logger.debug("Length: {:8.3f} {} (nice = {})"
@@ -381,9 +385,9 @@ def dateRange(dMin, dMax, step, unit, includeFirstBeyond = False):
"""
if (unit == DtUnit.YEARS or unit == DtUnit.MONTHS or
unit == DtUnit.MICRO_SECONDS):
-
- # Month and years will be converted to integers
- assert int(step) > 0, "Integer value or tickstep is 0"
+ # No support for fractional month or year and resolution is microsecond
+ # In those cases, make sure the step is at least 1
+ step = max(1, step)
else:
assert step > 0, "tickstep is 0"
diff --git a/silx/gui/plot/actions/control.py b/silx/gui/plot/actions/control.py
index ba69748..182ac78 100755
--- a/silx/gui/plot/actions/control.py
+++ b/silx/gui/plot/actions/control.py
@@ -50,7 +50,7 @@ from __future__ import division
__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
__license__ = "MIT"
-__date__ = "24/04/2018"
+__date__ = "27/11/2020"
from . import PlotAction
import logging
@@ -322,6 +322,7 @@ class ColormapAction(PlotAction):
:param plot: :class:`.PlotWidget` instance on which to operate
:param parent: See :class:`QAction`
"""
+
def __init__(self, plot, parent=None):
self._dialog = None # To store an instance of ColormapDialog
super(ColormapAction, self).__init__(
@@ -418,6 +419,7 @@ class ColorBarAction(PlotAction):
:param plot: :class:`.PlotWidget` instance on which to operate
:param parent: See :class:`QAction`
"""
+
def __init__(self, plot, parent=None):
self._dialog = None # To store an instance of ColorBar
super(ColorBarAction, self).__init__(
@@ -597,7 +599,7 @@ class ShowAxisAction(PlotAction):
triggered=self._actionTriggered,
checkable=True,
parent=parent)
- self.setChecked(self.plot._backend.isAxesDisplayed())
+ self.setChecked(self.plot.isAxesDisplayed())
plot._sigAxesVisibilityChanged.connect(self.setChecked)
def _actionTriggered(self, checked=False):
@@ -632,3 +634,76 @@ class ClosePolygonInteractionAction(PlotAction):
def _actionTriggered(self, checked=False):
self.plot._eventHandler.validate()
+
+
+class OpenGLAction(PlotAction):
+ """QAction controlling rendering of a :class:`.PlotWidget`.
+
+ For now it can enable or not the OpenGL backend.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ # Uses two images for checked/unchecked states
+ self._states = {
+ "opengl": (icons.getQIcon('backend-opengl'),
+ "OpenGL rendering (fast)\nClick to disable OpenGL"),
+ "matplotlib": (icons.getQIcon('backend-opengl'),
+ "Matplotlib rendering (safe)\nClick to enable OpenGL"),
+ "unknown": (icons.getQIcon('backend-opengl'),
+ "Custom rendering")
+ }
+
+ name = self._getBackendName(plot)
+ self.__state = name
+ icon, tooltip = self._states[name]
+ super(OpenGLAction, self).__init__(
+ plot,
+ icon=icon,
+ text='Enable/disable OpenGL rendering',
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent)
+
+ def _backendUpdated(self):
+ name = self._getBackendName(self.plot)
+ self.__state = name
+ icon, tooltip = self._states[name]
+ self.setIcon(icon)
+ self.setToolTip(tooltip)
+ self.setChecked(name == "opengl")
+
+ def _getBackendName(self, plot):
+ backend = plot.getBackend()
+ name = type(backend).__name__.lower()
+ if "opengl" in name:
+ return "opengl"
+ elif "matplotlib" in name:
+ return "matplotlib"
+ else:
+ return "unknown"
+
+ def _actionTriggered(self, checked=False):
+ plot = self.plot
+ name = self._getBackendName(self.plot)
+ if self.__state != name:
+ # THere is no event to know the backend was updated
+ # So here we check if there is a mismatch between the displayed state
+ # and the real state of the widget
+ self._backendUpdated()
+ return
+ if name != "opengl":
+ from silx.gui.utils import glutils
+ result = glutils.isOpenGLAvailable()
+ if not result:
+ qt.QMessageBox.critical(plot, "OpenGL rendering not available", result.error)
+ # Uncheck if needed
+ self._backendUpdated()
+ return
+ plot.setBackend("opengl")
+ else:
+ plot.setBackend("matplotlib")
+ self._backendUpdated()
diff --git a/silx/gui/plot/actions/io.py b/silx/gui/plot/actions/io.py
index 43b3b3a..f728b7a 100644
--- a/silx/gui/plot/actions/io.py
+++ b/silx/gui/plot/actions/io.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -37,7 +37,7 @@ from __future__ import division
__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
__license__ = "MIT"
-__date__ = "12/07/2018"
+__date__ = "25/09/2020"
from . import PlotAction
from silx.io.utils import save1D, savespec, NEXUS_HDF5_EXT
@@ -224,6 +224,43 @@ class SaveAction(PlotAction):
ylabel = item.getYLabel() or self.plot.getYAxis().getLabel()
return xlabel, ylabel
+ def _get1dData(self, item):
+ "provide xdata, [ydata], xlabel, [ylabel] and manages error bars"
+ xlabel, ylabel = self._getAxesLabels(item)
+ x_data = item.getXData(copy=False)
+ y_data = item.getYData(copy=False)
+ x_err = item.getXErrorData(copy=False)
+ y_err = item.getYErrorData(copy=False)
+ labels = [ylabel]
+ data = [y_data]
+
+ if x_err is not None:
+ if numpy.isscalar(x_err):
+ data.append(numpy.zeros_like(y_data) + x_err)
+ labels.append(xlabel + "_errors")
+ elif x_err.ndim == 1:
+ data.append(x_err)
+ labels.append(xlabel + "_errors")
+ elif x_err.ndim == 2:
+ data.append(x_err[0])
+ labels.append(xlabel + "_errors_below")
+ data.append(x_err[1])
+ labels.append(xlabel + "_errors_above")
+
+ if y_err is not None:
+ if numpy.isscalar(y_err):
+ data.append(numpy.zeros_like(y_data) + y_err)
+ labels.append(ylabel + "_errors")
+ elif y_err.ndim == 1:
+ data.append(y_err)
+ labels.append(ylabel + "_errors")
+ elif y_err.ndim == 2:
+ data.append(y_err[0])
+ labels.append(ylabel + "_errors_below")
+ data.append(y_err[1])
+ labels.append(ylabel + "_errors_above")
+ return x_data, data, xlabel, labels
+
@staticmethod
def _selectWriteableOutputGroup(filename, parent):
if os.path.exists(filename) and os.path.isfile(filename) \
@@ -291,16 +328,15 @@ class SaveAction(PlotAction):
# .npy or nxdata
fmt, csvdelim, autoheader = ("", "", False)
- xlabel, ylabel = self._getAxesLabels(curve)
-
if nameFilter == self.CURVE_FILTER_NXDATA:
return self._saveCurveAsNXdata(curve, filename)
+ xdata, data, xlabel, labels = self._get1dData(curve)
+
try:
save1D(filename,
- curve.getXData(copy=False),
- curve.getYData(copy=False),
- xlabel, [ylabel],
+ xdata, data,
+ xlabel, labels,
fmt=fmt, csvdelim=csvdelim,
autoheader=autoheader)
except IOError:
@@ -328,13 +364,11 @@ class SaveAction(PlotAction):
curve = curves[0]
scanno = 1
try:
- xlabel = curve.getXLabel() or plot.getGraphXLabel()
- ylabel = curve.getYLabel() or plot.getGraphYLabel(curve.getYAxis())
+ xdata, data, xlabel, labels = self._get1dData(curve)
+
specfile = savespec(filename,
- curve.getXData(copy=False),
- curve.getYData(copy=False),
- xlabel,
- ylabel,
+ xdata, data,
+ xlabel, labels,
fmt="%.7g", scan_number=1, mode="w",
write_file_header=True,
close_file=False)
@@ -345,13 +379,10 @@ class SaveAction(PlotAction):
for curve in curves[1:]:
try:
scanno += 1
- xlabel = curve.getXLabel() or plot.getGraphXLabel()
- ylabel = curve.getYLabel() or plot.getGraphYLabel(curve.getYAxis())
+ xdata, data, xlabel, labels = self._get1dData(curve)
specfile = savespec(specfile,
- curve.getXData(copy=False),
- curve.getYData(copy=False),
- xlabel,
- ylabel,
+ xdata, data,
+ xlabel, labels,
fmt="%.7g", scan_number=scanno,
write_file_header=False,
close_file=False)
@@ -629,7 +660,7 @@ class SaveAction(PlotAction):
# Check for correct file extension
# Extract file extensions as .something
extensions = [ext[ext.find('.'):] for ext in
- nameFilter[nameFilter.find('(')+1:-1].split()]
+ nameFilter[nameFilter.find('(') + 1:-1].split()]
for ext in extensions:
if (len(filename) > len(ext) and
filename[-len(ext):].lower() == ext.lower()):
diff --git a/silx/gui/plot/backends/BackendBase.py b/silx/gui/plot/backends/BackendBase.py
index bcc93a5..6fc1aa7 100755
--- a/silx/gui/plot/backends/BackendBase.py
+++ b/silx/gui/plot/backends/BackendBase.py
@@ -58,8 +58,8 @@ class BackendBase(object):
self.__yLimits = {'left': (1., 100.), 'right': (1., 100.)}
self.__yAxisInverted = False
self.__keepDataAspectRatio = False
+ self.__xAxisTimeSeries = False
self._xAxisTimeZone = None
- self._axesDisplayed = True
# Store a weakref to get access to the plot state.
self._setPlot(plot)
@@ -457,14 +457,14 @@ class BackendBase(object):
:rtype: bool
"""
- raise NotImplementedError()
+ return self.__xAxisTimeSeries
def setXAxisTimeSeries(self, isTimeSeries):
"""Set whether the X-axis is a time series
:param bool flag: True to switch to time series, False for regular axis.
"""
- raise NotImplementedError()
+ self.__xAxisTimeSeries = bool(isTimeSeries)
def setXAxisLogarithmic(self, flag):
"""Set the X axis scale between linear and log.
@@ -548,20 +548,17 @@ class BackendBase(object):
"""
raise NotImplementedError()
- def setAxesDisplayed(self, displayed):
- """Display or not the axes.
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
+ """Set the size of plot margins as ratios.
- :param bool displayed: If `True` axes are displayed. If `False` axes
- are not anymore visible and the margin used for them is removed.
- """
- self._axesDisplayed = displayed
+ Values are expected in [0., 1.]
- def isAxesDisplayed(self):
- """private because in some case it is possible that one of the two axes
- are displayed and not the other.
- This only check status set to axes from the public API
+ :param float left:
+ :param float top:
+ :param float right:
+ :param float bottom:
"""
- return self._axesDisplayed
+ pass
def setForegroundColors(self, foregroundColor, gridColor):
"""Set foreground and grid colors used to display this widget.
diff --git a/silx/gui/plot/backends/BackendMatplotlib.py b/silx/gui/plot/backends/BackendMatplotlib.py
index 036e630..140672f 100755
--- a/silx/gui/plot/backends/BackendMatplotlib.py
+++ b/silx/gui/plot/backends/BackendMatplotlib.py
@@ -33,6 +33,7 @@ __date__ = "21/12/2018"
import logging
import datetime as dt
+from typing import Tuple
import numpy
from pkg_resources import parse_version as _parse_version
@@ -44,7 +45,7 @@ _logger = logging.getLogger(__name__)
from ... import qt
# First of all init matplotlib and set its backend
-from ..matplotlib import FigureCanvasQTAgg
+from ...utils.matplotlib import FigureCanvasQTAgg
import matplotlib
from matplotlib.container import Container
from matplotlib.figure import Figure
@@ -593,7 +594,7 @@ class BackendMatplotlib(BackendBase.BackendBase):
if (len(color) == 4 and
type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
- color = numpy.array(color, dtype=numpy.float) / 255.
+ color = numpy.array(color, dtype=numpy.float64) / 255.
if yaxis == "right":
axes = self.ax2
@@ -601,7 +602,7 @@ class BackendMatplotlib(BackendBase.BackendBase):
else:
axes = self.ax
- picker = 3
+ pickradius = 3
artists = [] # All the artists composing the curve
@@ -627,7 +628,7 @@ class BackendMatplotlib(BackendBase.BackendBase):
if hasattr(color, 'dtype') and len(color) == len(x):
# scatter plot
- if color.dtype not in [numpy.float32, numpy.float]:
+ if color.dtype not in [numpy.float32, numpy.float64]:
actualColor = color / 255.
else:
actualColor = color
@@ -639,7 +640,8 @@ class BackendMatplotlib(BackendBase.BackendBase):
linestyle=linestyle,
color=actualColor[0],
linewidth=linewidth,
- picker=picker,
+ picker=True,
+ pickradius=pickradius,
marker=None)
artists += list(curveList)
@@ -647,7 +649,8 @@ class BackendMatplotlib(BackendBase.BackendBase):
scatter = axes.scatter(x, y,
color=actualColor,
marker=marker,
- picker=picker,
+ picker=True,
+ pickradius=pickradius,
s=symbolsize**2)
artists.append(scatter)
@@ -665,7 +668,8 @@ class BackendMatplotlib(BackendBase.BackendBase):
color=color,
linewidth=linewidth,
marker=symbol,
- picker=picker,
+ picker=True,
+ pickradius=pickradius,
markersize=symbolsize)
artists += list(curveList)
@@ -744,13 +748,13 @@ class BackendMatplotlib(BackendBase.BackendBase):
color = numpy.array(color, copy=False)
assert color.ndim == 2 and len(color) == len(x)
- if color.dtype not in [numpy.float32, numpy.float]:
+ if color.dtype not in [numpy.float32, numpy.float64]:
color = color.astype(numpy.float32) / 255.
collection = TriMesh(
Triangulation(x, y, triangles),
alpha=alpha,
- picker=0) # 0 enables picking on filled triangle
+ pickradius=0) # 0 enables picking on filled triangle
collection.set_color(color)
self.ax.add_collection(collection)
@@ -893,7 +897,8 @@ class BackendMatplotlib(BackendBase.BackendBase):
else:
raise RuntimeError('A marker must at least have one coordinate')
- line.set_picker(5)
+ line.set_picker(True)
+ line.set_pickradius(5)
# All markers are overlays
line.set_animated(True)
@@ -1014,7 +1019,11 @@ class BackendMatplotlib(BackendBase.BackendBase):
lambda item: item.isVisible() and item._backendRenderer is not None)
count = len(items)
for index, item in enumerate(items):
- zorder = 1. + index / count
+ if item.getZValue() < 0.5:
+ # Make sure matplotlib z order is below the grid (with z=0.5)
+ zorder = 0.5 * index / count
+ else: # Make sure matplotlib z order is above the grid (> 0.5)
+ zorder = 1. + index / count
if zorder != item._backendRenderer.get_zorder():
item._backendRenderer.set_zorder(zorder)
@@ -1196,67 +1205,58 @@ class BackendMatplotlib(BackendBase.BackendBase):
# Data <-> Pixel coordinates conversion
- def _mplQtYAxisCoordConversion(self, y, asint=True):
- """Qt origin (top) to/from matplotlib origin (bottom) conversion.
+ def _getDevicePixelRatio(self) -> float:
+ """Compatibility wrapper for devicePixelRatioF"""
+ return 1.
- :param y:
- :param bool asint: True to cast to int, False to keep as float
+ def _mplToQtPosition(self, x: float, y: float) -> Tuple[float, float]:
+ """Convert matplotlib "display" space coord to Qt widget logical pixel
+ """
+ ratio = self._getDevicePixelRatio()
+ # Convert from matplotlib origin (bottom) to Qt origin (top)
+ # and apply device pixel ratio
+ return x / ratio, (self.fig.get_window_extent().height - y) / ratio
- :rtype: float
+ def _qtToMplPosition(self, x: float, y: float) -> Tuple[float, float]:
+ """Convert Qt widget logical pixel to matplotlib "display" space coord
"""
- value = self.fig.get_window_extent().height - y
- return int(value) if asint else value
+ ratio = self._getDevicePixelRatio()
+ # Apply device pixel ration and
+ # convert from Qt origin (top) to matplotlib origin (bottom)
+ return x * ratio, self.fig.get_window_extent().height - (y * ratio)
def dataToPixel(self, x, y, axis):
ax = self.ax2 if axis == "right" else self.ax
-
- pixels = ax.transData.transform_point((x, y))
- xPixel, yPixel = pixels.T
-
- # Convert from matplotlib origin (bottom) to Qt origin (top)
- yPixel = self._mplQtYAxisCoordConversion(yPixel, asint=False)
-
- return xPixel, yPixel
+ displayPos = ax.transData.transform_point((x, y)).transpose()
+ return self._mplToQtPosition(*displayPos)
def pixelToData(self, x, y, axis):
ax = self.ax2 if axis == "right" else self.ax
-
- # Convert from Qt origin (top) to matplotlib origin (bottom)
- y = self._mplQtYAxisCoordConversion(y, asint=False)
-
- inv = ax.transData.inverted()
- x, y = inv.transform_point((x, y))
- return x, y
+ displayPos = self._qtToMplPosition(x, y)
+ return tuple(ax.transData.inverted().transform_point(displayPos))
def getPlotBoundsInPixels(self):
bbox = self.ax.get_window_extent()
# Warning this is not returning int...
- return (int(bbox.xmin),
- self._mplQtYAxisCoordConversion(bbox.ymax, asint=True),
- int(bbox.width),
- int(bbox.height))
+ ratio = self._getDevicePixelRatio()
+ return tuple(int(value / ratio) for value in (
+ bbox.xmin,
+ self.fig.get_window_extent().height - bbox.ymax,
+ bbox.width,
+ bbox.height))
- def setAxesDisplayed(self, displayed):
- """Display or not the axes.
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
+ width, height = 1. - left - right, 1. - top - bottom
+ position = left, bottom, width, height
+
+ # Toggle display of axes and viewbox rect
+ isFrameOn = position != (0., 0., 1., 1.)
+ self.ax.set_frame_on(isFrameOn)
+ self.ax2.set_frame_on(isFrameOn)
+
+ self.ax.set_position(position)
+ self.ax2.set_position(position)
- :param bool displayed: If `True` axes are displayed. If `False` axes
- are not anymore visible and the margin used for them is removed.
- """
- BackendBase.BackendBase.setAxesDisplayed(self, displayed)
- if displayed:
- # show axes and viewbox rect
- self.ax.set_frame_on(True)
- self.ax2.set_frame_on(True)
- # set the default margins
- self.ax.set_position([.15, .15, .75, .75])
- self.ax2.set_position([.15, .15, .75, .75])
- else:
- # hide axes and viewbox rect
- self.ax.set_frame_on(False)
- self.ax2.set_frame_on(False)
- # remove external margins
- self.ax.set_position([0, 0, 1, 1])
- self.ax2.set_position([0, 0, 1, 1])
self._synchronizeBackgroundColors()
self._synchronizeForegroundColors()
self._plot._setDirtyPlot()
@@ -1349,6 +1349,15 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib):
def postRedisplay(self):
self._sigPostRedisplay.emit()
+ def _getDevicePixelRatio(self) -> float:
+ """Compatibility wrapper for devicePixelRatioF"""
+ if hasattr(self, 'devicePixelRatioF'):
+ ratio = self.devicePixelRatioF()
+ else: # Qt < 5.6 compatibility
+ ratio = float(self.devicePixelRatio())
+ # Safety net: avoid returning 0
+ return ratio if ratio != 0. else 1.
+
# Mouse event forwarding
_MPL_TO_PLOT_BUTTONS = {1: 'left', 2: 'middle', 3: 'right'}
@@ -1356,17 +1365,14 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib):
def _onMousePress(self, event):
button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None)
if button is not None:
- self._plot.onMousePress(
- event.x, self._mplQtYAxisCoordConversion(event.y),
- button)
+ x, y = self._mplToQtPosition(event.x, event.y)
+ self._plot.onMousePress(int(x), int(y), button)
def _onMouseMove(self, event):
+ x, y = self._mplToQtPosition(event.x, event.y)
if self._graphCursor:
position = self._plot.pixelToData(
- event.x,
- self._mplQtYAxisCoordConversion(event.y),
- axis='left',
- check=True)
+ x, y, axis='left', check=True)
lineh, linev = self._graphCursor
if position is not None:
linev.set_visible(True)
@@ -1380,19 +1386,17 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib):
self._plot._setDirtyPlot(overlayOnly=True)
# onMouseMove must trigger replot if dirty flag is raised
- self._plot.onMouseMove(
- event.x, self._mplQtYAxisCoordConversion(event.y))
+ self._plot.onMouseMove(int(x), int(y))
def _onMouseRelease(self, event):
button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None)
if button is not None:
- self._plot.onMouseRelease(
- event.x, self._mplQtYAxisCoordConversion(event.y),
- button)
+ x, y = self._mplToQtPosition(event.x, event.y)
+ self._plot.onMouseRelease(int(x), int(y), button)
def _onMouseWheel(self, event):
- self._plot.onMouseWheel(
- event.x, self._mplQtYAxisCoordConversion(event.y), event.step)
+ x, y = self._mplToQtPosition(event.x, event.y)
+ self._plot.onMouseWheel(int(x), int(y), event.step)
def leaveEvent(self, event):
"""QWidget event handler"""
@@ -1406,8 +1410,9 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib):
# picking
def pickItem(self, x, y, item):
+ xDisplay, yDisplay = self._qtToMplPosition(x, y)
mouseEvent = MouseEvent(
- 'button_press_event', self, x, self._mplQtYAxisCoordConversion(y))
+ 'button_press_event', self, int(xDisplay), int(yDisplay))
# Override axes and data position with the axes
mouseEvent.inaxes = item.axes
mouseEvent.xdata, mouseEvent.ydata = self.pixelToData(
diff --git a/silx/gui/plot/backends/BackendOpenGL.py b/silx/gui/plot/backends/BackendOpenGL.py
index cf1da31..909d18a 100755
--- a/silx/gui/plot/backends/BackendOpenGL.py
+++ b/silx/gui/plot/backends/BackendOpenGL.py
@@ -43,12 +43,7 @@ from ... import qt
from ..._glutils import gl
from ... import _glutils as glu
-from .glutils import (
- GLLines2D, GLPlotTriangles,
- GLPlotCurve2D, GLPlotColormap, GLPlotRGBAImage, GLPlotFrame2D,
- mat4Ortho, mat4Identity,
- LEFT, RIGHT, BOTTOM, TOP,
- Text2D, FilledShape2D)
+from . import glutils
from .glutils.PlotImageFile import saveImageToFile
_logger = logging.getLogger(__name__)
@@ -216,7 +211,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
self._backgroundColor = 1., 1., 1., 1.
self._dataBackgroundColor = 1., 1., 1., 1.
- self.matScreenProj = mat4Identity()
+ self.matScreenProj = glutils.mat4Identity()
self._progBase = glu.Program(
_baseVertShd, _baseFragShd, attrib0='position')
@@ -231,10 +226,13 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
self._glGarbageCollector = []
- self._plotFrame = GLPlotFrame2D(
+ self._plotFrame = glutils.GLPlotFrame2D(
foregroundColor=(0., 0., 0., 1.),
gridColor=(.7, .7, .7, 1.),
- margins={'left': 100, 'right': 50, 'top': 50, 'bottom': 50})
+ marginRatios=(.15, .1, .1, .15))
+ self._plotFrame.size = ( # Init size with size int
+ int(self.getDevicePixelRatio() * 640),
+ int(self.getDevicePixelRatio() * 480))
# Make postRedisplay asynchronous using Qt signal
self._sigPostRedisplay.connect(
@@ -254,50 +252,43 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
def mousePressEvent(self, event):
if event.button() not in self._MOUSE_BTNS:
return super(BackendOpenGL, self).mousePressEvent(event)
- xPixel = event.x() * self.getDevicePixelRatio()
- yPixel = event.y() * self.getDevicePixelRatio()
- btn = self._MOUSE_BTNS[event.button()]
- self._plot.onMousePress(xPixel, yPixel, btn)
+ self._plot.onMousePress(
+ event.x(), event.y(), self._MOUSE_BTNS[event.button()])
event.accept()
def mouseMoveEvent(self, event):
- xPixel = event.x() * self.getDevicePixelRatio()
- yPixel = event.y() * self.getDevicePixelRatio()
-
- # Handle crosshair
- inXPixel, inYPixel = self._mouseInPlotArea(xPixel, yPixel)
- isCursorInPlot = inXPixel == xPixel and inYPixel == yPixel
+ qtPos = event.x(), event.y()
previousMousePosInPixels = self._mousePosInPixels
- self._mousePosInPixels = (xPixel, yPixel) if isCursorInPlot else None
+ if qtPos == self._mouseInPlotArea(*qtPos):
+ devicePixelRatio = self.getDevicePixelRatio()
+ devicePos = qtPos[0] * devicePixelRatio, qtPos[1] * devicePixelRatio
+ self._mousePosInPixels = devicePos # Mouse in plot area
+ else:
+ self._mousePosInPixels = None # Mouse outside plot area
+
if (self._crosshairCursor is not None and
previousMousePosInPixels != self._mousePosInPixels):
# Avoid replot when cursor remains outside plot area
self._plot._setDirtyPlot(overlayOnly=True)
- self._plot.onMouseMove(xPixel, yPixel)
+ self._plot.onMouseMove(*qtPos)
event.accept()
def mouseReleaseEvent(self, event):
if event.button() not in self._MOUSE_BTNS:
return super(BackendOpenGL, self).mouseReleaseEvent(event)
- xPixel = event.x() * self.getDevicePixelRatio()
- yPixel = event.y() * self.getDevicePixelRatio()
-
- btn = self._MOUSE_BTNS[event.button()]
- self._plot.onMouseRelease(xPixel, yPixel, btn)
+ self._plot.onMouseRelease(
+ event.x(), event.y(), self._MOUSE_BTNS[event.button()])
event.accept()
def wheelEvent(self, event):
- xPixel = event.x() * self.getDevicePixelRatio()
- yPixel = event.y() * self.getDevicePixelRatio()
-
if hasattr(event, 'angleDelta'): # Qt 5
delta = event.angleDelta().y()
else: # Qt 4 support
delta = event.delta()
angleInDegrees = delta / 8.
- self._plot.onMouseWheel(xPixel, yPixel, angleInDegrees)
+ self._plot.onMouseWheel(event.x(), event.y(), angleInDegrees)
event.accept()
def leaveEvent(self, _):
@@ -371,7 +362,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
gl.glUniform1i(self._progTex.uniforms['tex'], texUnit)
gl.glUniformMatrix4fv(self._progTex.uniforms['matrix'], 1, gl.GL_TRUE,
- mat4Identity().astype(numpy.float32))
+ glutils.mat4Identity().astype(numpy.float32))
gl.glEnableVertexAttribArray(self._progTex.attributes['position'])
gl.glVertexAttribPointer(self._progTex.attributes['position'],
@@ -405,10 +396,11 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
# Check if window is large enough
- plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
- if plotWidth <= 2 or plotHeight <= 2:
+ if self._plotFrame.plotSize <= (2, 2):
return
+ # Sync plot frame with window
+ self._plotFrame.devicePixelRatio = self.getDevicePixelRatio()
# self._paintDirectGL()
self._paintFBOGL()
@@ -422,7 +414,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
True to render items that are overlays.
"""
# Values that are often used
- plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+ plotWidth, plotHeight = self._plotFrame.plotSize
isXLog = self._plotFrame.xAxis.isLog
isYLog = self._plotFrame.yAxis.isLog
isYInverted = self._plotFrame.isYAxisInverted
@@ -431,6 +423,9 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
labels = []
pixelOffset = 3
+ context = glutils.RenderContext(
+ isXLog=isXLog, isYLog=isYLog, dpi=self.getDotsPerInch())
+
for plotItem in self.getItemsFromBackToFront(
condition=lambda i: i.isVisible() and i.isOverlay() == overlay):
if plotItem._backendRenderer is None:
@@ -438,20 +433,16 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
item = plotItem._backendRenderer
- if isinstance(item, (GLPlotCurve2D,
- GLPlotColormap,
- GLPlotRGBAImage,
- GLPlotTriangles)): # Render data items
+ if isinstance(item, glutils.GLPlotItem): # Render data items
gl.glViewport(self._plotFrame.margins.left,
self._plotFrame.margins.bottom,
plotWidth, plotHeight)
-
- if isinstance(item, GLPlotCurve2D) and item.info.get('yAxis') == 'right':
- item.render(self._plotFrame.transformedDataY2ProjMat,
- isXLog, isYLog)
+ # Set matrix
+ if item.yaxis == 'right':
+ context.matrix = self._plotFrame.transformedDataY2ProjMat
else:
- item.render(self._plotFrame.transformedDataProjMat,
- isXLog, isYLog)
+ context.matrix = self._plotFrame.transformedDataProjMat
+ item.render(context)
elif isinstance(item, _ShapeItem): # Render shape items
gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
@@ -463,53 +454,67 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if item['shape'] == 'hline':
width = self._plotFrame.size[0]
- _, yPixel = self._plot.dataToPixel(
- None, item['y'], axis='left', check=False)
- points = numpy.array(((0., yPixel), (width, yPixel)),
- dtype=numpy.float32)
+ _, yPixel = self._plotFrame.dataToPixel(
+ 0.5 * sum(self._plotFrame.dataRanges[0]),
+ item['y'],
+ axis='left')
+ subShapes = [numpy.array(((0., yPixel), (width, yPixel)),
+ dtype=numpy.float32)]
elif item['shape'] == 'vline':
- xPixel, _ = self._plot.dataToPixel(
- item['x'], None, axis='left', check=False)
+ xPixel, _ = self._plotFrame.dataToPixel(
+ item['x'],
+ 0.5 * sum(self._plotFrame.dataRanges[1]),
+ axis='left')
height = self._plotFrame.size[1]
- points = numpy.array(((xPixel, 0), (xPixel, height)),
- dtype=numpy.float32)
+ subShapes = [numpy.array(((xPixel, 0), (xPixel, height)),
+ dtype=numpy.float32)]
else:
- points = numpy.array([
- self._plot.dataToPixel(x, y, axis='left', check=False)
- for (x, y) in zip(item['x'], item['y'])])
-
- # Draw the fill
- if (item['fill'] is not None and
- item['shape'] not in ('hline', 'vline')):
- self._progBase.use()
- gl.glUniformMatrix4fv(
- self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
- self.matScreenProj.astype(numpy.float32))
- gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
- gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
-
- shape2D = FilledShape2D(
- points, style=item['fill'], color=item['color'])
- shape2D.render(
- posAttrib=self._progBase.attributes['position'],
- colorUnif=self._progBase.uniforms['color'],
- hatchStepUnif=self._progBase.uniforms['hatchStep'])
-
- # Draw the stroke
- if item['linestyle'] not in ('', ' ', None):
- if item['shape'] != 'polylines':
- # close the polyline
- points = numpy.append(points,
- numpy.atleast_2d(points[0]), axis=0)
-
- lines = GLLines2D(points[:, 0], points[:, 1],
- style=item['linestyle'],
- color=item['color'],
- dash2ndColor=item['linebgcolor'],
- width=item['linewidth'])
- lines.render(self.matScreenProj)
+ # Split sub-shapes at not finite values
+ splits = numpy.nonzero(numpy.logical_not(numpy.logical_and(
+ numpy.isfinite(item['x']), numpy.isfinite(item['y']))))[0]
+ splits = numpy.concatenate(([-1], splits, [len(item['x'])]))
+ subShapes = []
+ for begin, end in zip(splits[:-1] + 1, splits[1:]):
+ if end > begin:
+ subShapes.append(numpy.array([
+ self._plotFrame.dataToPixel(x, y, axis='left')
+ for (x, y) in zip(item['x'][begin:end], item['y'][begin:end])]))
+
+ for points in subShapes: # Draw each sub-shape
+ # Draw the fill
+ if (item['fill'] is not None and
+ item['shape'] not in ('hline', 'vline')):
+ self._progBase.use()
+ gl.glUniformMatrix4fv(
+ self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
+ self.matScreenProj.astype(numpy.float32))
+ gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
+ gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
+
+ shape2D = glutils.FilledShape2D(
+ points, style=item['fill'], color=item['color'])
+ shape2D.render(
+ posAttrib=self._progBase.attributes['position'],
+ colorUnif=self._progBase.uniforms['color'],
+ hatchStepUnif=self._progBase.uniforms['hatchStep'])
+
+ # Draw the stroke
+ if item['linestyle'] not in ('', ' ', None):
+ if item['shape'] != 'polylines':
+ # close the polyline
+ points = numpy.append(points,
+ numpy.atleast_2d(points[0]), axis=0)
+
+ lines = glutils.GLLines2D(
+ points[:, 0], points[:, 1],
+ style=item['linestyle'],
+ color=item['color'],
+ dash2ndColor=item['linebgcolor'],
+ width=item['linewidth'])
+ context.matrix = self.matScreenProj
+ lines.render(context)
elif isinstance(item, _MarkerItem):
gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
@@ -522,76 +527,103 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
continue
if xCoord is None or yCoord is None:
- pixelPos = self._plot.dataToPixel(
- xCoord, yCoord, axis=yAxis, check=False)
-
if xCoord is None: # Horizontal line in data space
+ pixelPos = self._plotFrame.dataToPixel(
+ 0.5 * sum(self._plotFrame.dataRanges[0]),
+ yCoord,
+ axis=yAxis)
+
if item['text'] is not None:
x = self._plotFrame.size[0] - \
self._plotFrame.margins.right - pixelOffset
y = pixelPos[1] - pixelOffset
- label = Text2D(item['text'], x, y,
- color=item['color'],
- bgColor=(1., 1., 1., 0.5),
- align=RIGHT, valign=BOTTOM)
+ label = glutils.Text2D(
+ item['text'], x, y,
+ color=item['color'],
+ bgColor=(1., 1., 1., 0.5),
+ align=glutils.RIGHT,
+ valign=glutils.BOTTOM,
+ devicePixelRatio=self.getDevicePixelRatio())
labels.append(label)
width = self._plotFrame.size[0]
- lines = GLLines2D((0, width), (pixelPos[1], pixelPos[1]),
- style=item['linestyle'],
- color=item['color'],
- width=item['linewidth'])
- lines.render(self.matScreenProj)
+ lines = glutils.GLLines2D(
+ (0, width), (pixelPos[1], pixelPos[1]),
+ style=item['linestyle'],
+ color=item['color'],
+ width=item['linewidth'])
+ context.matrix = self.matScreenProj
+ lines.render(context)
else: # yCoord is None: vertical line in data space
+ yRange = self._plotFrame.dataRanges[1 if yAxis == 'left' else 2]
+ pixelPos = self._plotFrame.dataToPixel(
+ xCoord, 0.5 * sum(yRange), axis=yAxis)
+
if item['text'] is not None:
x = pixelPos[0] + pixelOffset
y = self._plotFrame.margins.top + pixelOffset
- label = Text2D(item['text'], x, y,
- color=item['color'],
- bgColor=(1., 1., 1., 0.5),
- align=LEFT, valign=TOP)
+ label = glutils.Text2D(
+ item['text'], x, y,
+ color=item['color'],
+ bgColor=(1., 1., 1., 0.5),
+ align=glutils.LEFT,
+ valign=glutils.TOP,
+ devicePixelRatio=self.getDevicePixelRatio())
labels.append(label)
height = self._plotFrame.size[1]
- lines = GLLines2D((pixelPos[0], pixelPos[0]), (0, height),
- style=item['linestyle'],
- color=item['color'],
- width=item['linewidth'])
- lines.render(self.matScreenProj)
+ lines = glutils.GLLines2D(
+ (pixelPos[0], pixelPos[0]), (0, height),
+ style=item['linestyle'],
+ color=item['color'],
+ width=item['linewidth'])
+ context.matrix = self.matScreenProj
+ lines.render(context)
else:
- pixelPos = self._plot.dataToPixel(
- xCoord, yCoord, axis=yAxis, check=True)
- if pixelPos is None:
+ xmin, xmax = self._plot.getXAxis().getLimits()
+ ymin, ymax = self._plot.getYAxis(axis=yAxis).getLimits()
+ if not xmin < xCoord < xmax or not ymin < yCoord < ymax:
# Do not render markers outside visible plot area
continue
+ pixelPos = self._plotFrame.dataToPixel(
+ xCoord, yCoord, axis=yAxis)
if isYInverted:
- valign = BOTTOM
+ valign = glutils.BOTTOM
vPixelOffset = -pixelOffset
else:
- valign = TOP
+ valign = glutils.TOP
vPixelOffset = pixelOffset
if item['text'] is not None:
x = pixelPos[0] + pixelOffset
y = pixelPos[1] + vPixelOffset
- label = Text2D(item['text'], x, y,
- color=item['color'],
- bgColor=(1., 1., 1., 0.5),
- align=LEFT, valign=valign)
+ label = glutils.Text2D(
+ item['text'], x, y,
+ color=item['color'],
+ bgColor=(1., 1., 1., 0.5),
+ align=glutils.LEFT,
+ valign=valign,
+ devicePixelRatio=self.getDevicePixelRatio())
labels.append(label)
# For now simple implementation: using a curve for each marker
# Should pack all markers to a single set of points
- markerCurve = GLPlotCurve2D(
+ markerCurve = glutils.GLPlotCurve2D(
numpy.array((pixelPos[0],), dtype=numpy.float64),
numpy.array((pixelPos[1],), dtype=numpy.float64),
marker=item['symbol'],
markerColor=item['color'],
markerSize=11)
- markerCurve.render(self.matScreenProj, False, False)
+
+ context = glutils.RenderContext(
+ matrix=self.matScreenProj,
+ isXLog=False,
+ isYLog=False,
+ dpi=self.getDotsPerInch())
+ markerCurve.render(context)
markerCurve.discard()
else:
@@ -605,7 +637,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
def _renderOverlayGL(self):
"""Render overlay layer: overlay items and crosshair."""
- plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+ plotWidth, plotHeight = self._plotFrame.plotSize
# Scissor to plot area
gl.glScissor(self._plotFrame.margins.left,
@@ -658,7 +690,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
It renders the background, grid and items except overlays
"""
- plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+ plotWidth, plotHeight = self._plotFrame.plotSize
gl.glScissor(self._plotFrame.margins.left,
self._plotFrame.margins.bottom,
@@ -687,9 +719,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
int(self.getDevicePixelRatio() * width),
int(self.getDevicePixelRatio() * height))
- self.matScreenProj = mat4Ortho(0, self._plotFrame.size[0],
- self._plotFrame.size[1], 0,
- 1, -1)
+ self.matScreenProj = glutils.mat4Ortho(
+ 0, self._plotFrame.size[0],
+ self._plotFrame.size[1], 0,
+ 1, -1)
# Store current ranges
previousXRange = self.getGraphXLimits()
@@ -824,21 +857,20 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
fillColor = None
if fill is True:
fillColor = color
- curve = GLPlotCurve2D(x, y, colorArray,
- xError=xerror,
- yError=yerror,
- lineStyle=linestyle,
- lineColor=color,
- lineWidth=linewidth,
- marker=symbol,
- markerColor=color,
- markerSize=symbolsize,
- fillColor=fillColor,
- baseline=baseline,
- isYLog=isYLog)
- curve.info = {
- 'yAxis': 'left' if yaxis is None else yaxis,
- }
+ curve = glutils.GLPlotCurve2D(
+ x, y, colorArray,
+ xError=xerror,
+ yError=yerror,
+ lineStyle=linestyle,
+ lineColor=color,
+ lineWidth=linewidth,
+ marker=symbol,
+ markerColor=color,
+ markerSize=symbolsize,
+ fillColor=fillColor,
+ baseline=baseline,
+ isYLog=isYLog)
+ curve.yaxis = 'left' if yaxis is None else yaxis
if yaxis == "right":
self._plotFrame.isY2Axis = True
@@ -853,7 +885,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if data.ndim == 2:
# Ensure array is contiguous and eventually convert its type
- if data.dtype in (numpy.float32, numpy.uint8, numpy.uint16):
+ dtypes = [dtype for dtype in (
+ numpy.float32, numpy.float16, numpy.uint8, numpy.uint16)
+ if glu.isSupportedGLType(dtype)]
+ if data.dtype in dtypes:
data = numpy.array(data, copy=False, order='C')
else:
_logger.info(
@@ -861,24 +896,27 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
data = numpy.array(data, dtype=numpy.float32, order='C')
normalization = colormap.getNormalization()
- if normalization in GLPlotColormap.SUPPORTED_NORMALIZATIONS:
+ if normalization in glutils.GLPlotColormap.SUPPORTED_NORMALIZATIONS:
# Fast path applying colormap on the GPU
cmapRange = colormap.getColormapRange(data=data)
colormapLut = colormap.getNColors(nbColors=256)
gamma = colormap.getGammaNormalizationParameter()
-
- image = GLPlotColormap(data,
- origin,
- scale,
- colormapLut,
- normalization,
- gamma,
- cmapRange,
- alpha)
+ nanColor = colors.rgba(colormap.getNaNColor())
+
+ image = glutils.GLPlotColormap(
+ data,
+ origin,
+ scale,
+ colormapLut,
+ normalization,
+ gamma,
+ cmapRange,
+ alpha,
+ nanColor)
else: # Fallback applying colormap on CPU
rgba = colormap.applyToData(data)
- image = GLPlotRGBAImage(rgba, origin, scale, alpha)
+ image = glutils.GLPlotRGBAImage(rgba, origin, scale, alpha)
elif len(data.shape) == 3:
# For RGB, RGBA data
@@ -893,7 +931,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
else:
raise ValueError('Unsupported data type')
- image = GLPlotRGBAImage(data, origin, scale, alpha)
+ image = glutils.GLPlotRGBAImage(data, origin, scale, alpha)
else:
raise RuntimeError("Unsupported data shape {0}".format(data.shape))
@@ -916,7 +954,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if self._plotFrame.yAxis.isLog:
y = numpy.log10(y)
- triangles = GLPlotTriangles(x, y, color, triangles, alpha)
+ triangles = glutils.GLPlotTriangles(x, y, color, triangles, alpha)
return triangles
@@ -944,11 +982,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
# Remove methods
def remove(self, item):
- if isinstance(item, (GLPlotCurve2D,
- GLPlotColormap,
- GLPlotRGBAImage,
- GLPlotTriangles)):
- if isinstance(item, GLPlotCurve2D):
+ if isinstance(item, glutils.GLPlotItem):
+ if item.yaxis == 'right':
# Check if some curves remains on the right Y axis
y2AxisItems = (item for item in self._plot.getItems()
if isinstance(item, items.YAxisMixIn) and
@@ -997,13 +1032,18 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
_PICK_OFFSET = 3 # Offset in pixel used for picking
def _mouseInPlotArea(self, x, y):
- xPlot = numpy.clip(
- x, self._plotFrame.margins.left,
- self._plotFrame.size[0] - self._plotFrame.margins.right - 1)
- yPlot = numpy.clip(
- y, self._plotFrame.margins.top,
- self._plotFrame.size[1] - self._plotFrame.margins.bottom - 1)
- return xPlot, yPlot
+ """Returns closest visible position in the plot.
+
+ This is performed in Qt widget pixel, not device pixel.
+
+ :param float x: X coordinate in Qt widget pixel
+ :param float y: Y coordinate in Qt widget pixel
+ :return: (x, y) closest point in the plot.
+ :rtype: List[float]
+ """
+ left, top, width, height = self.getPlotBoundsInPixels()
+ return (numpy.clip(x, left, left + width - 1), # TODO -1?
+ numpy.clip(y, top, top + height - 1))
def __pickCurves(self, item, x, y):
"""Perform picking on a curve item.
@@ -1016,22 +1056,26 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
"""
offset = self._PICK_OFFSET
if item.marker is not None:
- offset = max(item.markerSize / 2., offset)
+ # Convert markerSize from points to qt pixels
+ qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio()
+ size = item.markerSize / 72. * qtDpi
+ offset = max(size / 2., offset)
if item.lineStyle is not None:
- offset = max(item.lineWidth / 2., offset)
-
- yAxis = item.info['yAxis']
+ # Convert line width from points to qt pixels
+ qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio()
+ lineWidth = item.lineWidth / 72. * qtDpi
+ offset = max(lineWidth / 2., offset)
inAreaPos = self._mouseInPlotArea(x - offset, y - offset)
dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1],
- axis=yAxis, check=True)
+ axis=item.yaxis, check=True)
if dataPos is None:
return None
xPick0, yPick0 = dataPos
inAreaPos = self._mouseInPlotArea(x + offset, y + offset)
dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1],
- axis=yAxis, check=True)
+ axis=item.yaxis, check=True)
if dataPos is None:
return None
xPick1, yPick1 = dataPos
@@ -1051,8 +1095,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
xPickMin = numpy.log10(xPickMin)
xPickMax = numpy.log10(xPickMax)
- if (yAxis == 'left' and self._plotFrame.yAxis.isLog) or (
- yAxis == 'right' and self._plotFrame.y2Axis.isLog):
+ if (item.yaxis == 'left' and self._plotFrame.yAxis.isLog) or (
+ item.yaxis == 'right' and self._plotFrame.y2Axis.isLog):
yPickMin = numpy.log10(yPickMin)
yPickMax = numpy.log10(yPickMax)
@@ -1060,6 +1104,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
xPickMax, yPickMax)
def pickItem(self, x, y, item):
+ # Picking is performed in Qt widget pixels not device pixels
dataPos = self._plot.pixelToData(x, y, axis='left', check=True)
if dataPos is None:
return None # Outside plot area
@@ -1100,17 +1145,11 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
return (0,) if isPicked else None
# Pick image, curve, triangles
- elif isinstance(item, (GLPlotCurve2D,
- GLPlotColormap,
- GLPlotRGBAImage,
- GLPlotTriangles)):
- if isinstance(item, (GLPlotColormap, GLPlotRGBAImage, GLPlotTriangles)):
- return item.pick(*dataPos) # Might be None
-
- elif isinstance(item, GLPlotCurve2D):
+ elif isinstance(item, glutils.GLPlotItem):
+ if isinstance(item, glutils.GLPlotCurve2D):
return self.__pickCurves(item, x, y)
else:
- return None
+ return item.pick(*dataPos) # Might be None
# Update curve
@@ -1184,8 +1223,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if axis == 'left':
self._plotFrame.yAxis.title = label
else: # right axis
- if label:
- _logger.warning('Right axis label not implemented')
+ self._plotFrame.y2Axis.title = label
# Graph limits
@@ -1209,7 +1247,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
:param str keepDim: The dimension to maintain: 'x', 'y' or None.
If None (the default), the dimension with the largest range.
"""
- plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+ plotWidth, plotHeight = self._plotFrame.plotSize
if plotWidth <= 2 or plotHeight <= 2:
return
@@ -1352,17 +1390,25 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
# Data <-> Pixel coordinates conversion
def dataToPixel(self, x, y, axis):
- return self._plotFrame.dataToPixel(x, y, axis)
+ result = self._plotFrame.dataToPixel(x, y, axis)
+ if result is None:
+ return None
+ else:
+ devicePixelRatio = self.getDevicePixelRatio()
+ return tuple(value/devicePixelRatio for value in result)
def pixelToData(self, x, y, axis):
- return self._plotFrame.pixelToData(x, y, axis)
+ devicePixelRatio = self.getDevicePixelRatio()
+ return self._plotFrame.pixelToData(
+ x * devicePixelRatio, y * devicePixelRatio, axis)
def getPlotBoundsInPixels(self):
- return self._plotFrame.plotOrigin + self._plotFrame.plotSize
+ devicePixelRatio = self.getDevicePixelRatio()
+ return tuple(int(value / devicePixelRatio)
+ for value in self._plotFrame.plotOrigin + self._plotFrame.plotSize)
- def setAxesDisplayed(self, displayed):
- BackendBase.BackendBase.setAxesDisplayed(self, displayed)
- self._plotFrame.displayed = displayed
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
+ self._plotFrame.marginRatios = left, top, right, bottom
def setForegroundColors(self, foregroundColor, gridColor):
self._plotFrame.foregroundColor = foregroundColor
diff --git a/silx/gui/plot/backends/glutils/GLPlotCurve.py b/silx/gui/plot/backends/glutils/GLPlotCurve.py
index 9ab85fd..c4e2c1e 100644
--- a/silx/gui/plot/backends/glutils/GLPlotCurve.py
+++ b/silx/gui/plot/backends/glutils/GLPlotCurve.py
@@ -43,6 +43,7 @@ from silx.math.combo import min_max
from ...._glutils import gl
from ...._glutils import Program, vertexBuffer, VertexBufferAttrib
from .GLSupport import buildFillMaskIndices, mat4Identity, mat4Translate
+from .GLPlotImage import GLPlotItem
_logger = logging.getLogger(__name__)
@@ -172,10 +173,10 @@ class _Fill2D(object):
self._xFillVboData, self._yFillVboData = vertexBuffer(points.T)
- def render(self, matrix):
+ def render(self, context):
"""Perform rendering
- :param numpy.ndarray matrix: 4x4 transform matrix to use
+ :param RenderContext context:
"""
self.prepare()
@@ -186,7 +187,7 @@ class _Fill2D(object):
gl.glUniformMatrix4fv(
self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE,
- numpy.dot(matrix,
+ numpy.dot(context.matrix,
mat4Translate(*self.offset)).astype(numpy.float32))
gl.glUniform4f(self._PROGRAM.uniforms['color'], *self.color)
@@ -404,11 +405,13 @@ class GLLines2D(object):
"""OpenGL context initialization"""
gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
- def render(self, matrix):
+ def render(self, context):
"""Perform rendering
- :param numpy.ndarray matrix: 4x4 transform matrix to use
+ :param RenderContext context:
"""
+ width = self.width / 72. * context.dpi
+
style = self.style
if style is None:
return
@@ -425,7 +428,7 @@ class GLLines2D(object):
gl.glUniform2f(program.uniforms['halfViewportSize'],
0.5 * viewWidth, 0.5 * viewHeight)
- dashPeriod = self.dashPeriod * self.width
+ dashPeriod = self.dashPeriod * width
if self.style == DOTTED:
dash = (0.2 * dashPeriod,
0.5 * dashPeriod,
@@ -463,10 +466,10 @@ class GLLines2D(object):
0,
self.distVboData)
- if self.width != 1:
+ if width != 1:
gl.glEnable(gl.GL_LINE_SMOOTH)
- matrix = numpy.dot(matrix,
+ matrix = numpy.dot(context.matrix,
mat4Translate(*self.offset)).astype(numpy.float32)
gl.glUniformMatrix4fv(program.uniforms['matrix'],
1, gl.GL_TRUE, matrix)
@@ -503,7 +506,7 @@ class GLLines2D(object):
0,
self.yVboData)
- gl.glLineWidth(self.width)
+ gl.glLineWidth(width)
gl.glDrawArrays(self._drawMode, 0, self.xVboData.size)
gl.glDisable(gl.GL_LINE_SMOOTH)
@@ -516,10 +519,26 @@ def distancesFromArrays(xData, yData):
:param numpy.ndarray yData: Y coordinate of points
:rtype: numpy.ndarray
"""
- deltas = numpy.dstack((
- numpy.ediff1d(xData, to_begin=numpy.float32(0.)),
- numpy.ediff1d(yData, to_begin=numpy.float32(0.))))[0]
- return numpy.cumsum(numpy.sqrt(numpy.sum(deltas ** 2, axis=1)))
+ # Split array into sub-shapes at not finite points
+ splits = numpy.nonzero(numpy.logical_not(numpy.logical_and(
+ numpy.isfinite(xData), numpy.isfinite(yData))))[0]
+ splits = numpy.concatenate(([-1], splits, [len(xData) - 1]))
+
+ # Compute distance independently for each sub-shapes,
+ # putting not finite points as last points of sub-shapes
+ distances = []
+ for begin, end in zip(splits[:-1] + 1, splits[1:] + 1):
+ if begin == end: # Empty shape
+ continue
+ elif end - begin == 1: # Single element
+ distances.append([0])
+ else:
+ deltas = numpy.dstack((
+ numpy.ediff1d(xData[begin:end], to_begin=numpy.float32(0.)),
+ numpy.ediff1d(yData[begin:end], to_begin=numpy.float32(0.))))[0]
+ distances.append(
+ numpy.cumsum(numpy.sqrt(numpy.sum(deltas ** 2, axis=1))))
+ return numpy.concatenate(distances)
# points ######################################################################
@@ -833,10 +852,10 @@ class _Points2D(object):
if majorVersion >= 3: # OpenGL 3
gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
- def render(self, matrix):
+ def render(self, context):
"""Perform rendering
- :param numpy.ndarray matrix: 4x4 transform matrix to use
+ :param RenderContext context:
"""
if self.marker is None:
return
@@ -844,7 +863,7 @@ class _Points2D(object):
program = self._getProgram(self.marker)
program.use()
- matrix = numpy.dot(matrix,
+ matrix = numpy.dot(context.matrix,
mat4Translate(*self.offset)).astype(numpy.float32)
gl.glUniformMatrix4fv(program.uniforms['matrix'], 1, gl.GL_TRUE, matrix)
@@ -854,6 +873,13 @@ class _Points2D(object):
size = math.ceil(0.5 * self.size) + 1 # Mimic Matplotlib point
else:
size = self.size
+ size = size / 72. * context.dpi
+
+ if self.marker in (PLUS, H_LINE, V_LINE,
+ TICK_LEFT, TICK_RIGHT, TICK_UP, TICK_DOWN):
+ # Convert to nearest odd number
+ size = size // 2 * 2 + 1.
+
gl.glUniform1f(program.uniforms['size'], size)
# gl.glPointSize(self.size)
@@ -1021,17 +1047,17 @@ class _ErrorBars(object):
self._yErrPoints.yVboData.offset += (yAttrib.itemsize *
yAttrib.size // 2)
- def render(self, matrix):
+ def render(self, context):
"""Perform rendering
- :param numpy.ndarray matrix: 4x4 transform matrix to use
+ :param RenderContext context:
"""
self.prepare()
if self._attribs is not None:
- self._lines.render(matrix)
- self._xErrPoints.render(matrix)
- self._yErrPoints.render(matrix)
+ self._lines.render(context)
+ self._xErrPoints.render(context)
+ self._yErrPoints.render(context)
def discard(self):
"""Release VBOs"""
@@ -1067,7 +1093,7 @@ def _proxyProperty(*componentsAttributes):
return property(getter, setter)
-class GLPlotCurve2D(object):
+class GLPlotCurve2D(GLPlotItem):
def __init__(self, xData, yData, colorData=None,
xError=None, yError=None,
lineStyle=SOLID,
@@ -1080,7 +1106,7 @@ class GLPlotCurve2D(object):
fillColor=None,
baseline=None,
isYLog=False):
-
+ super().__init__()
self.colorData = colorData
# Compute x bounds
@@ -1220,19 +1246,17 @@ class GLPlotCurve2D(object):
self.colorVboData = cAttrib
self.useColorVboData = cAttrib is not None
- def render(self, matrix, isXLog, isYLog):
+ def render(self, context):
"""Perform rendering
- :param numpy.ndarray matrix: 4x4 transform matrix to use
- :param bool isXLog:
- :param bool isYLog:
+ :param RenderContext context: Rendering information
"""
self.prepare()
if self.fill is not None:
- self.fill.render(matrix)
- self._errorBars.render(matrix)
- self.lines.render(matrix)
- self.points.render(matrix)
+ self.fill.render(context)
+ self._errorBars.render(context)
+ self.lines.render(context)
+ self.points.render(context)
def discard(self):
"""Release VBOs"""
diff --git a/silx/gui/plot/backends/glutils/GLPlotFrame.py b/silx/gui/plot/backends/glutils/GLPlotFrame.py
index 43f6e10..c5ee75b 100644
--- a/silx/gui/plot/backends/glutils/GLPlotFrame.py
+++ b/silx/gui/plot/backends/glutils/GLPlotFrame.py
@@ -61,7 +61,7 @@ class PlotAxis(object):
This class is intended to be used with :class:`GLPlotFrame`.
"""
- def __init__(self, plot,
+ def __init__(self, plotFrame,
tickLength=(0., 0.),
foregroundColor=(0., 0., 0., 1.0),
labelAlign=CENTER, labelVAlign=CENTER,
@@ -69,7 +69,7 @@ class PlotAxis(object):
titleRotate=0, titleOffset=(0., 0.)):
self._ticks = None
- self._plot = weakref.ref(plot)
+ self._plotFrameRef = weakref.ref(plotFrame)
self._isDateTime = False
self._timeZone = None
@@ -157,6 +157,12 @@ class PlotAxis(object):
self._dirtyTicks()
@property
+ def devicePixelRatio(self):
+ """Returns the ratio between qt pixels and device pixels."""
+ plotFrame = self._plotFrameRef()
+ return plotFrame.devicePixelRatio if plotFrame is not None else 1.
+
+ @property
def title(self):
"""The text label associated with this axis as a str in latin-1."""
return self._title
@@ -165,10 +171,18 @@ class PlotAxis(object):
def title(self, title):
if title != self._title:
self._title = title
+ self._dirtyPlotFrame()
- plot = self._plot()
- if plot is not None:
- plot._dirty()
+ @property
+ def titleOffset(self):
+ """Title offset in pixels (x: int, y: int)"""
+ return self._titleOffset
+
+ @titleOffset.setter
+ def titleOffset(self, offset):
+ if offset != self._titleOffset:
+ self._titleOffset = offset
+ self._dirtyTicks()
@property
def foregroundColor(self):
@@ -201,6 +215,8 @@ class PlotAxis(object):
tickLabelsSize = [0., 0.]
xTickLength, yTickLength = self._tickLength
+ xTickLength *= self.devicePixelRatio
+ yTickLength *= self.devicePixelRatio
for (xPixel, yPixel), dataPos, text in self.ticks:
if text is None:
tickScale = 0.5
@@ -212,7 +228,8 @@ class PlotAxis(object):
x=xPixel - xTickLength,
y=yPixel - yTickLength,
align=self._labelAlign,
- valign=self._labelVAlign)
+ valign=self._labelVAlign,
+ devicePixelRatio=self.devicePixelRatio)
width, height = label.size
if width > tickLabelsSize[0]:
@@ -230,7 +247,7 @@ class PlotAxis(object):
xAxisCenter = 0.5 * (x0 + x1)
yAxisCenter = 0.5 * (y0 + y1)
- xOffset, yOffset = self._titleOffset
+ xOffset, yOffset = self.titleOffset
# Adaptative title positioning:
# tickNorm = math.sqrt(xTickLength ** 2 + yTickLength ** 2)
@@ -245,17 +262,22 @@ class PlotAxis(object):
y=yAxisCenter + yOffset,
align=self._titleAlign,
valign=self._titleVAlign,
- rotate=self._titleRotate)
+ rotate=self._titleRotate,
+ devicePixelRatio=self.devicePixelRatio)
labels.append(axisTitle)
return vertices, labels
+ def _dirtyPlotFrame(self):
+ """Dirty parent GLPlotFrame"""
+ plotFrame = self._plotFrameRef()
+ if plotFrame is not None:
+ plotFrame._dirty()
+
def _dirtyTicks(self):
"""Mark ticks as dirty and notify listener (i.e., background)."""
self._ticks = None
- plot = self._plot()
- if plot is not None:
- plot._dirty()
+ self._dirtyPlotFrame()
@staticmethod
def _frange(start, stop, step):
@@ -314,7 +336,7 @@ class PlotAxis(object):
xScale = (x1 - x0) / (dataMax - dataMin)
yScale = (y1 - y0) / (dataMax - dataMin)
- nbPixels = math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2))
+ nbPixels = math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2)) / self.devicePixelRatio
# Density of 1.3 label per 92 pixels
# i.e., 1.3 label per inch on a 92 dpi screen
@@ -391,11 +413,11 @@ class GLPlotFrame(object):
# Margins used when plot frame is not displayed
_NoDisplayMargins = _Margins(0, 0, 0, 0)
- def __init__(self, margins, foregroundColor, gridColor):
+ def __init__(self, marginRatios, foregroundColor, gridColor):
"""
- :param margins: The margins around plot area for axis and labels.
- :type margins: dict with 'left', 'right', 'top', 'bottom' keys and
- values as ints.
+ :param List[float] marginRatios:
+ The ratios of margins around plot area for axis and labels.
+ (left, top, right, bottom) as float in [0., 1.]
:param foregroundColor: color used for the frame and labels.
:type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
:param gridColor: color used for grid lines.
@@ -403,7 +425,9 @@ class GLPlotFrame(object):
"""
self._renderResources = None
- self._margins = self._Margins(**margins)
+ self.__marginRatios = marginRatios
+ self.__marginsCache = None
+
self._foregroundColor = foregroundColor
self._gridColor = gridColor
@@ -412,7 +436,8 @@ class GLPlotFrame(object):
self._grid = False
self._size = 0., 0.
self._title = ''
- self._displayed = True
+
+ self._devicePixelRatio = 1.
@property
def isDirty(self):
@@ -453,26 +478,49 @@ class GLPlotFrame(object):
if self._gridColor != color:
self._gridColor = color
self._dirty()
-
+
@property
- def displayed(self):
- """Whether axes and their labels are displayed or not (bool)"""
- return self._displayed
-
- @displayed.setter
- def displayed(self, displayed):
- displayed = bool(displayed)
- if displayed != self._displayed:
- self._displayed = displayed
+ def marginRatios(self):
+ """Plot margin ratios: (left, top, right, bottom) as 4 float in [0, 1].
+ """
+ return self.__marginRatios
+
+ @marginRatios.setter
+ def marginRatios(self, ratios):
+ ratios = tuple(float(v) for v in ratios)
+ assert len(ratios) == 4
+ for value in ratios:
+ assert 0. <= value <= 1.
+ assert ratios[0] + ratios[2] < 1.
+ assert ratios[1] + ratios[3] < 1.
+
+ if self.__marginRatios != ratios:
+ self.__marginRatios = ratios
+ self.__marginsCache = None # Clear cached margins
self._dirty()
@property
def margins(self):
"""Margins in pixels around the plot."""
- if not self.displayed:
- return self._NoDisplayMargins
- else:
- return self._margins
+ if self.__marginsCache is None:
+ width, height = self.size
+ left, top, right, bottom = self.marginRatios
+ self.__marginsCache = self._Margins(
+ left=int(left*width),
+ right=int(right*width),
+ top=int(top*height),
+ bottom=int(bottom*height))
+ return self.__marginsCache
+
+ @property
+ def devicePixelRatio(self):
+ return self._devicePixelRatio
+
+ @devicePixelRatio.setter
+ def devicePixelRatio(self, ratio):
+ if ratio != self._devicePixelRatio:
+ self._devicePixelRatio = ratio
+ self._dirty()
@property
def grid(self):
@@ -493,7 +541,7 @@ class GLPlotFrame(object):
@property
def size(self):
- """Size in pixels of the plot area including margins."""
+ """Size in device pixels of the plot area including margins."""
return self._size
@size.setter
@@ -502,6 +550,7 @@ class GLPlotFrame(object):
size = tuple(size)
if size != self._size:
self._size = size
+ self.__marginsCache = None # Clear cached margins
self._dirty()
@property
@@ -580,7 +629,8 @@ class GLPlotFrame(object):
x=xTitle,
y=yTitle,
align=CENTER,
- valign=BOTTOM))
+ valign=BOTTOM,
+ devicePixelRatio=self.devicePixelRatio))
# grid
gridVertices = numpy.array(self._buildGridVertices(),
@@ -592,7 +642,7 @@ class GLPlotFrame(object):
_SHADERS['vertex'], _SHADERS['fragment'], attrib0='position')
def render(self):
- if not self.displayed:
+ if self.margins == self._NoDisplayMargins:
return
if self._renderResources is None:
@@ -661,25 +711,24 @@ class GLPlotFrame(object):
# GLPlotFrame2D ###############################################################
class GLPlotFrame2D(GLPlotFrame):
- def __init__(self, margins, foregroundColor, gridColor):
+ def __init__(self, marginRatios, foregroundColor, gridColor):
"""
- :param margins: The margins around plot area for axis and labels.
- :type margins: dict with 'left', 'right', 'top', 'bottom' keys and
- values as ints.
+ :param List[float] marginRatios:
+ The ratios of margins around plot area for axis and labels.
+ (left, top, right, bottom) as float in [0., 1.]
:param foregroundColor: color used for the frame and labels.
:type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
:param gridColor: color used for grid lines.
:type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
"""
- super(GLPlotFrame2D, self).__init__(margins, foregroundColor, gridColor)
+ super(GLPlotFrame2D, self).__init__(marginRatios, foregroundColor, gridColor)
self.axes.append(PlotAxis(self,
tickLength=(0., -5.),
foregroundColor=self._foregroundColor,
labelAlign=CENTER, labelVAlign=TOP,
titleAlign=CENTER, titleVAlign=TOP,
- titleRotate=0,
- titleOffset=(0, self.margins.bottom // 2)))
+ titleRotate=0))
self._x2AxisCoords = ()
@@ -688,18 +737,14 @@ class GLPlotFrame2D(GLPlotFrame):
foregroundColor=self._foregroundColor,
labelAlign=RIGHT, labelVAlign=CENTER,
titleAlign=CENTER, titleVAlign=BOTTOM,
- titleRotate=ROTATE_270,
- titleOffset=(-3 * self.margins.left // 4,
- 0)))
+ titleRotate=ROTATE_270))
self._y2Axis = PlotAxis(self,
tickLength=(-5., 0.),
foregroundColor=self._foregroundColor,
labelAlign=LEFT, labelVAlign=CENTER,
titleAlign=CENTER, titleVAlign=TOP,
- titleRotate=ROTATE_270,
- titleOffset=(3 * self.margins.right // 4,
- 0))
+ titleRotate=ROTATE_270)
self._isYAxisInverted = False
@@ -794,6 +839,24 @@ class GLPlotFrame2D(GLPlotFrame):
self._baseVectors = vectors
self._dirty()
+ def _updateTitleOffset(self):
+ """Update axes title offset according to margins"""
+ margins = self.margins
+ self.xAxis.titleOffset = 0, margins.bottom // 2
+ self.yAxis.titleOffset = -3 * margins.left // 4, 0
+ self.y2Axis.titleOffset = 3 * margins.right // 4, 0
+
+ # Override size and marginRatios setters to update titleOffsets
+ @GLPlotFrame.size.setter
+ def size(self, size):
+ GLPlotFrame.size.fset(self, size)
+ self._updateTitleOffset()
+
+ @GLPlotFrame.marginRatios.setter
+ def marginRatios(self, ratios):
+ GLPlotFrame.marginRatios.fset(self, ratios)
+ self._updateTitleOffset()
+
@property
def dataRanges(self):
"""Ranges of data visible in the plot on x, y and y2 axes.
diff --git a/silx/gui/plot/backends/glutils/GLPlotImage.py b/silx/gui/plot/backends/glutils/GLPlotImage.py
index e985a3d..f60a159 100644
--- a/silx/gui/plot/backends/glutils/GLPlotImage.py
+++ b/silx/gui/plot/backends/glutils/GLPlotImage.py
@@ -40,10 +40,12 @@ from ...._glutils import gl, Program, Texture
from ..._utils import FLOAT32_MINPOS
from .GLSupport import mat4Translate, mat4Scale
from .GLTexture import Image
+from .GLPlotItem import GLPlotItem
-class _GLPlotData2D(object):
+class _GLPlotData2D(GLPlotItem):
def __init__(self, data, origin, scale):
+ super().__init__()
self.data = data
assert len(origin) == 2
self.origin = tuple(origin)
@@ -80,15 +82,6 @@ class _GLPlotData2D(object):
oy, sy = self.origin[1], self.scale[1]
return oy + sy * self.data.shape[0] if sy >= 0. else oy
- def discard(self):
- pass
-
- def prepare(self):
- pass
-
- def render(self, matrix, isXLog, isYLog):
- pass
-
class GLPlotColormap(_GLPlotData2D):
@@ -160,6 +153,11 @@ class GLPlotColormap(_GLPlotData2D):
'fragment': """
#version 120
+ /* isnan declaration for compatibility with GLSL 1.20 */
+ bool isnan(float value) {
+ return (value != value);
+ }
+
uniform sampler2D data;
uniform sampler2D cmap_texture;
uniform int cmap_normalization;
@@ -167,6 +165,7 @@ class GLPlotColormap(_GLPlotData2D):
uniform float cmap_min;
uniform float cmap_oneOverRange;
uniform float alpha;
+ uniform vec4 nancolor;
varying vec2 coords;
@@ -175,7 +174,8 @@ class GLPlotColormap(_GLPlotData2D):
const float oneOverLog10 = 0.43429448190325176;
void main(void) {
- float value = texture2D(data, textureCoords()).r;
+ float data = texture2D(data, textureCoords()).r;
+ float value = data;
if (cmap_normalization == 1) { /*Logarithm mapping*/
if (value > 0.) {
value = clamp(cmap_oneOverRange *
@@ -202,7 +202,11 @@ class GLPlotColormap(_GLPlotData2D):
value = clamp(cmap_oneOverRange * (value - cmap_min), 0., 1.);
}
- gl_FragColor = texture2D(cmap_texture, vec2(value, 0.5));
+ if (isnan(data)) {
+ gl_FragColor = nancolor;
+ } else {
+ gl_FragColor = texture2D(cmap_texture, vec2(value, 0.5));
+ }
gl_FragColor.a *= alpha;
}
"""
@@ -213,6 +217,7 @@ class GLPlotColormap(_GLPlotData2D):
_INTERNAL_FORMATS = {
numpy.dtype(numpy.float32): gl.GL_R32F,
+ numpy.dtype(numpy.float16): gl.GL_R16F,
# Use normalized integer for unsigned int formats
numpy.dtype(numpy.uint16): gl.GL_R16,
numpy.dtype(numpy.uint8): gl.GL_R8,
@@ -232,7 +237,7 @@ class GLPlotColormap(_GLPlotData2D):
def __init__(self, data, origin, scale,
colormap, normalization='linear', gamma=0., cmapRange=None,
- alpha=1.0):
+ alpha=1.0, nancolor=(1., 1., 1., 0.)):
"""Create a 2D colormap
:param data: The 2D scalar data array to display
@@ -252,6 +257,8 @@ class GLPlotColormap(_GLPlotData2D):
TODO: check consistency with matplotlib
:type cmapRange: (float, float) or None
:param float alpha: Opacity from 0 (transparent) to 1 (opaque)
+ :param nancolor: RGBA color for Not-A-Number values
+ :type nancolor: 4-tuple of float in [0., 1.]
"""
assert data.dtype in self._INTERNAL_FORMATS
assert normalization in self.SUPPORTED_NORMALIZATIONS
@@ -263,6 +270,7 @@ class GLPlotColormap(_GLPlotData2D):
self._cmapRange = (1., 10.) # Colormap range
self.cmapRange = cmapRange # Update _cmapRange
self._alpha = numpy.clip(alpha, 0., 1.)
+ self._nancolor = numpy.clip(nancolor, 0., 1.)
self._cmap_texture = None
self._texture = None
@@ -283,7 +291,7 @@ class GLPlotColormap(_GLPlotData2D):
if self.normalization == 'log':
assert self._cmapRange[0] > 0. and self._cmapRange[1] > 0.
elif self.normalization == 'sqrt':
- assert self._cmapRange[0] >= 0. and self._cmapRange[1] > 0.
+ assert self._cmapRange[0] >= 0. and self._cmapRange[1] >= 0.
return self._cmapRange
@cmapRange.setter
@@ -324,6 +332,7 @@ class GLPlotColormap(_GLPlotData2D):
magFilter=gl.GL_NEAREST,
wrap=(gl.GL_CLAMP_TO_EDGE,
gl.GL_CLAMP_TO_EDGE))
+ self._cmap_texture.prepare()
if self._texture is None:
internalFormat = self._INTERNAL_FORMATS[self.data.dtype]
@@ -376,9 +385,15 @@ class GLPlotColormap(_GLPlotData2D):
oneOverRange = 0. # Fall-back
gl.glUniform1f(prog.uniforms['cmap_oneOverRange'], oneOverRange)
+ gl.glUniform4f(prog.uniforms['nancolor'], *self._nancolor)
+
self._cmap_texture.bind()
- def _renderLinear(self, matrix):
+ def _renderLinear(self, context):
+ """Perform rendering when both axes have linear scales
+
+ :param RenderContext context: Rendering information
+ """
self.prepare()
prog = self._linearProgram
@@ -386,7 +401,7 @@ class GLPlotColormap(_GLPlotData2D):
gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT)
- mat = numpy.dot(numpy.dot(matrix,
+ mat = numpy.dot(numpy.dot(context.matrix,
mat4Translate(*self.origin)),
mat4Scale(*self.scale))
gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
@@ -400,10 +415,14 @@ class GLPlotColormap(_GLPlotData2D):
prog.attributes['texCoords'],
self._DATA_TEX_UNIT)
- def _renderLog10(self, matrix, isXLog, isYLog):
+ def _renderLog10(self, context):
+ """Perform rendering when one axis has log scale
+
+ :param RenderContext context: Rendering information
+ """
xMin, yMin = self.xMin, self.yMin
- if ((isXLog and xMin < FLOAT32_MINPOS) or
- (isYLog and yMin < FLOAT32_MINPOS)):
+ if ((context.isXLog and xMin < FLOAT32_MINPOS) or
+ (context.isYLog and yMin < FLOAT32_MINPOS)):
# Do not render images that are partly or totally <= 0
return
@@ -417,12 +436,12 @@ class GLPlotColormap(_GLPlotData2D):
gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT)
gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
- matrix.astype(numpy.float32))
+ context.matrix.astype(numpy.float32))
mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale))
gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE,
mat.astype(numpy.float32))
- gl.glUniform2i(prog.uniforms['isLog'], isXLog, isYLog)
+ gl.glUniform2i(prog.uniforms['isLog'], context.isXLog, context.isYLog)
ex = ox + self.scale[0] * self.data.shape[1]
ey = oy + self.scale[1] * self.data.shape[0]
@@ -461,11 +480,15 @@ class GLPlotColormap(_GLPlotData2D):
gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
- def render(self, matrix, isXLog, isYLog):
- if any((isXLog, isYLog)):
- self._renderLog10(matrix, isXLog, isYLog)
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context: Rendering information
+ """
+ if any((context.isXLog, context.isYLog)):
+ self._renderLog10(context)
else:
- self._renderLinear(matrix)
+ self._renderLinear(context)
# Unbind colormap texture
gl.glActiveTexture(gl.GL_TEXTURE0 + self._cmap_texture.texUnit)
@@ -635,7 +658,11 @@ class GLPlotRGBAImage(_GLPlotData2D):
format_ = gl.GL_RGBA if self.data.shape[2] == 4 else gl.GL_RGB
self._texture.updateAll(format_=format_, data=self.data)
- def _renderLinear(self, matrix):
+ def _renderLinear(self, context):
+ """Perform rendering with both axes having linear scales
+
+ :param RenderContext context: Rendering information
+ """
self.prepare()
prog = self._linearProgram
@@ -643,7 +670,7 @@ class GLPlotRGBAImage(_GLPlotData2D):
gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT)
- mat = numpy.dot(numpy.dot(matrix, mat4Translate(*self.origin)),
+ mat = numpy.dot(numpy.dot(context.matrix, mat4Translate(*self.origin)),
mat4Scale(*self.scale))
gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
mat.astype(numpy.float32))
@@ -654,7 +681,11 @@ class GLPlotRGBAImage(_GLPlotData2D):
prog.attributes['texCoords'],
self._DATA_TEX_UNIT)
- def _renderLog(self, matrix, isXLog, isYLog):
+ def _renderLog(self, context):
+ """Perform rendering with axes having log scale
+
+ :param RenderContext context: Rendering information
+ """
self.prepare()
prog = self._logProgram
@@ -665,12 +696,12 @@ class GLPlotRGBAImage(_GLPlotData2D):
gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT)
gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
- matrix.astype(numpy.float32))
+ context.matrix.astype(numpy.float32))
mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale))
gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE,
mat.astype(numpy.float32))
- gl.glUniform2i(prog.uniforms['isLog'], isXLog, isYLog)
+ gl.glUniform2i(prog.uniforms['isLog'], context.isXLog, context.isYLog)
gl.glUniform1f(prog.uniforms['alpha'], self.alpha)
@@ -707,8 +738,12 @@ class GLPlotRGBAImage(_GLPlotData2D):
gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
- def render(self, matrix, isXLog, isYLog):
- if any((isXLog, isYLog)):
- self._renderLog(matrix, isXLog, isYLog)
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context: Rendering information
+ """
+ if any((context.isXLog, context.isYLog)):
+ self._renderLog(context)
else:
- self._renderLinear(matrix)
+ self._renderLinear(context)
diff --git a/silx/gui/plot/backends/glutils/GLPlotItem.py b/silx/gui/plot/backends/glutils/GLPlotItem.py
new file mode 100644
index 0000000..899f38e
--- /dev/null
+++ b/silx/gui/plot/backends/glutils/GLPlotItem.py
@@ -0,0 +1,94 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides a base class for PlotWidget OpenGL backend primitives
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/07/2020"
+
+
+class RenderContext:
+ """Context with which to perform OpenGL rendering.
+
+ :param numpy.ndarray matrix: 4x4 transform matrix to use for rendering
+ :param bool isXLog: Whether X axis is log scale or not
+ :param bool isYLog: Whether Y axis is log scale or not
+ :param float dpi: Number of device pixels per inch
+ """
+
+ def __init__(self, matrix=None, isXLog=False, isYLog=False, dpi=96.):
+ self.matrix = matrix
+ """Current transformation matrix"""
+
+ self.__isXLog = isXLog
+ self.__isYLog = isYLog
+ self.__dpi = dpi
+
+ @property
+ def isXLog(self):
+ """True if X axis is using log scale"""
+ return self.__isXLog
+
+ @property
+ def isYLog(self):
+ """True if Y axis is using log scale"""
+ return self.__isYLog
+
+ @property
+ def dpi(self):
+ """Number of device pixels per inch"""
+ return self.__dpi
+
+
+class GLPlotItem:
+ """Base class for primitives used in the PlotWidget OpenGL backend"""
+
+ def __init__(self):
+ self.yaxis = 'left'
+ "YAxis this item is attached to (either 'left' or 'right')"
+
+ def pick(self, x, y):
+ """Perform picking at given position.
+
+ :param float x: X coordinate in plot data frame of reference
+ :param float y: Y coordinate in plot data frame of reference
+ :returns:
+ Result of picking as a list of indices or None if nothing picked
+ :rtype: Union[List[int],None]
+ """
+ return None
+
+ def render(self, context):
+ """Performs OpenGL rendering of the item.
+
+ :param RenderContext context: Rendering context information
+ """
+ pass
+
+ def discard(self):
+ """Discards OpenGL resources this item has created."""
+ pass
diff --git a/silx/gui/plot/backends/glutils/GLPlotTriangles.py b/silx/gui/plot/backends/glutils/GLPlotTriangles.py
index 7aeb5ab..d5ba1a6 100644
--- a/silx/gui/plot/backends/glutils/GLPlotTriangles.py
+++ b/silx/gui/plot/backends/glutils/GLPlotTriangles.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2019 European Synchrotron Radiation Facility
+# Copyright (c) 2019-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -38,9 +38,10 @@ import numpy
from .....math.combo import min_max
from .... import _glutils as glutils
from ...._glutils import gl
+from .GLPlotItem import GLPlotItem
-class GLPlotTriangles(object):
+class GLPlotTriangles(GLPlotItem):
"""Handle rendering of a set of colored triangles"""
_PROGRAM = glutils.Program(
@@ -81,6 +82,7 @@ class GLPlotTriangles(object):
:param numpy.ndarray triangles: (N, 3) array of indices of triangles
:param float alpha: Opacity in [0, 1]
"""
+ super().__init__()
# Check and convert input data
x = numpy.ravel(numpy.array(x, dtype=numpy.float32))
y = numpy.ravel(numpy.array(y, dtype=numpy.float32))
@@ -161,12 +163,10 @@ class GLPlotTriangles(object):
usage=gl.GL_STATIC_DRAW,
target=gl.GL_ELEMENT_ARRAY_BUFFER)
- def render(self, matrix, isXLog, isYLog):
+ def render(self, context):
"""Perform rendering
- :param numpy.ndarray matrix: 4x4 transform matrix to use
- :param bool isXLog:
- :param bool isYLog:
+ :param RenderContext context: Rendering information
"""
self.prepare()
@@ -178,7 +178,7 @@ class GLPlotTriangles(object):
gl.glUniformMatrix4fv(self._PROGRAM.uniforms['matrix'],
1,
gl.GL_TRUE,
- matrix.astype(numpy.float32))
+ context.matrix.astype(numpy.float32))
gl.glUniform1f(self._PROGRAM.uniforms['alpha'], self.__alpha)
diff --git a/silx/gui/plot/backends/glutils/GLText.py b/silx/gui/plot/backends/glutils/GLText.py
index 725c12c..d6ae6fa 100644
--- a/silx/gui/plot/backends/glutils/GLText.py
+++ b/silx/gui/plot/backends/glutils/GLText.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -140,7 +140,9 @@ class Text2D(object):
color=(0., 0., 0., 1.),
bgColor=None,
align=LEFT, valign=BASELINE,
- rotate=0):
+ rotate=0,
+ devicePixelRatio= 1.):
+ self.devicePixelRatio = devicePixelRatio
self._vertices = None
self._text = text
self.x = x
@@ -160,30 +162,35 @@ class Text2D(object):
self._rotate = numpy.radians(rotate)
- def _getTexture(self, text):
+ def _getTexture(self, text, devicePixelRatio):
# Retrieve/initialize texture cache for current context
+ textureKey = text, devicePixelRatio
+
context = Context.getCurrent()
if context not in self._textures:
self._textures[context] = _Cache(
callback=lambda key, value: value[0].discard())
textures = self._textures[context]
- if text not in textures:
- image, offset = font.rasterText(text,
- font.getDefaultFontFamily())
- if text not in self._sizes:
- self._sizes[text] = image.shape[1], image.shape[0]
-
- textures[text] = (
- Texture(gl.GL_RED,
- data=image,
- minFilter=gl.GL_NEAREST,
- magFilter=gl.GL_NEAREST,
- wrap=(gl.GL_CLAMP_TO_EDGE,
- gl.GL_CLAMP_TO_EDGE)),
- offset)
-
- return textures[text]
+ if textureKey not in textures:
+ image, offset = font.rasterText(
+ text,
+ font.getDefaultFontFamily(),
+ devicePixelRatio=self.devicePixelRatio)
+ if textureKey not in self._sizes:
+ self._sizes[textureKey] = image.shape[1], image.shape[0]
+
+ texture = Texture(
+ gl.GL_RED,
+ data=image,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=(gl.GL_CLAMP_TO_EDGE,
+ gl.GL_CLAMP_TO_EDGE))
+ texture.prepare()
+ textures[textureKey] = texture, offset
+
+ return textures[textureKey]
@property
def text(self):
@@ -191,11 +198,14 @@ class Text2D(object):
@property
def size(self):
- if self.text not in self._sizes:
- image, offset = font.rasterText(self.text,
- font.getDefaultFontFamily())
- self._sizes[self.text] = image.shape[1], image.shape[0]
- return self._sizes[self.text]
+ textureKey = self.text, self.devicePixelRatio
+ if textureKey not in self._sizes:
+ image, offset = font.rasterText(
+ self.text,
+ font.getDefaultFontFamily(),
+ devicePixelRatio=self.devicePixelRatio)
+ self._sizes[textureKey] = image.shape[1], image.shape[0]
+ return self._sizes[textureKey]
def getVertices(self, offset, shape):
height, width = shape
@@ -238,7 +248,7 @@ class Text2D(object):
prog.use()
texUnit = 0
- texture, offset = self._getTexture(self.text)
+ texture, offset = self._getTexture(self.text, self.devicePixelRatio)
gl.glUniform1i(prog.uniforms['texText'], texUnit)
diff --git a/silx/gui/plot/backends/glutils/GLTexture.py b/silx/gui/plot/backends/glutils/GLTexture.py
index 118a36f..37fbdd0 100644
--- a/silx/gui/plot/backends/glutils/GLTexture.py
+++ b/silx/gui/plot/backends/glutils/GLTexture.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -98,6 +98,7 @@ class Image(object):
minFilter=self._MIN_FILTER,
magFilter=self._MAG_FILTER,
wrap=self._WRAP)
+ texture.prepare()
vertices = numpy.array((
(0., 0., 0., 0.),
(self.width, 0., 1., 0.),
@@ -177,6 +178,7 @@ class Image(object):
(xOrig, yOrig + hData, 0., vMax),
(xOrig + wData, yOrig + hData, uMax, vMax)),
dtype=numpy.float32)
+ texture.prepare()
tiles.append((texture, vertices,
{'xOrigData': xOrig, 'yOrigData': yOrig,
'wData': wData, 'hData': hData}))
@@ -203,6 +205,7 @@ class Image(object):
texture.update(format_,
data[yOrig:yOrig+height, xOrig:xOrig+width],
texUnit=texUnit)
+ texture.prepare()
# TODO check
# width=info['wData'], height=info['hData'],
# texUnit=texUnit, unpackAlign=unpackAlign,
diff --git a/silx/gui/plot/backends/glutils/__init__.py b/silx/gui/plot/backends/glutils/__init__.py
index d58c084..f87d7c1 100644
--- a/silx/gui/plot/backends/glutils/__init__.py
+++ b/silx/gui/plot/backends/glutils/__init__.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -39,6 +39,7 @@ _logger = logging.getLogger(__name__)
from .GLPlotCurve import * # noqa
from .GLPlotFrame import * # noqa
from .GLPlotImage import * # noqa
+from .GLPlotItem import GLPlotItem, RenderContext # noqa
from .GLPlotTriangles import GLPlotTriangles # noqa
from .GLSupport import * # noqa
from .GLText import * # noqa
diff --git a/silx/gui/plot/items/__init__.py b/silx/gui/plot/items/__init__.py
index 4d4eac0..0484025 100644
--- a/silx/gui/plot/items/__init__.py
+++ b/silx/gui/plot/items/__init__.py
@@ -32,7 +32,8 @@ __authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "22/06/2017"
-from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa
+from .core import (Item, DataItem, # noqa
+ LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa
SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa
AlphaMixIn, LineMixIn, ScatterVisualizationMixIn, # noqa
ComplexMixIn, ItemChangedType, PointsBase) # noqa
diff --git a/silx/gui/plot/items/_arc_roi.py b/silx/gui/plot/items/_arc_roi.py
new file mode 100644
index 0000000..a22cc3d
--- /dev/null
+++ b/silx/gui/plot/items/_arc_roi.py
@@ -0,0 +1,873 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides Arc ROI item for the :class:`~silx.gui.plot.PlotWidget`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+import numpy
+
+from ... import utils
+from .. import items
+from ...colors import rgba
+from ....utils.proxy import docstring
+from ._roi_base import HandleBasedROI
+from ._roi_base import InteractionModeMixIn
+from ._roi_base import RoiInteractionMode
+
+
+class _ArcGeometry:
+ """
+ Non-mutable object to store the geometry of the arc ROI.
+
+ The aim is is to switch between consistent state without dealing with
+ intermediate values.
+ """
+ def __init__(self, center, startPoint, endPoint, radius,
+ weight, startAngle, endAngle, closed=False):
+ """Constructor for a consistent arc geometry.
+
+ There is also specific class method to create different kind of arc
+ geometry.
+ """
+ self.center = center
+ self.startPoint = startPoint
+ self.endPoint = endPoint
+ self.radius = radius
+ self.weight = weight
+ self.startAngle = startAngle
+ self.endAngle = endAngle
+ self._closed = closed
+
+ @classmethod
+ def createEmpty(cls):
+ """Create an arc geometry from an empty shape
+ """
+ zero = numpy.array([0, 0])
+ return cls(zero, zero.copy(), zero.copy(), 0, 0, 0, 0)
+
+ @classmethod
+ def createRect(cls, startPoint, endPoint, weight):
+ """Create an arc geometry from a definition of a rectangle
+ """
+ return cls(None, startPoint, endPoint, None, weight, None, None, False)
+
+ @classmethod
+ def createCircle(cls, center, startPoint, endPoint, radius,
+ weight, startAngle, endAngle):
+ """Create an arc geometry from a definition of a circle
+ """
+ return cls(center, startPoint, endPoint, radius,
+ weight, startAngle, endAngle, True)
+
+ def withWeight(self, weight):
+ """Return a new geometry based on this object, with a specific weight
+ """
+ return _ArcGeometry(self.center, self.startPoint, self.endPoint,
+ self.radius, weight,
+ self.startAngle, self.endAngle, self._closed)
+
+ def withRadius(self, radius):
+ """Return a new geometry based on this object, with a specific radius.
+
+ The weight and the center is conserved.
+ """
+ startPoint = self.center + (self.startPoint - self.center) / self.radius * radius
+ endPoint = self.center + (self.endPoint - self.center) / self.radius * radius
+ return _ArcGeometry(self.center, startPoint, endPoint,
+ radius, self.weight,
+ self.startAngle, self.endAngle, self._closed)
+
+ def withStartAngle(self, startAngle):
+ """Return a new geometry based on this object, with a specific start angle
+ """
+ vector = numpy.array([numpy.cos(startAngle), numpy.sin(startAngle)])
+ startPoint = self.center + vector * self.radius
+
+ # Never add more than 180 to maintain coherency
+ deltaAngle = startAngle - self.startAngle
+ if deltaAngle > numpy.pi:
+ deltaAngle -= numpy.pi * 2
+ elif deltaAngle < -numpy.pi:
+ deltaAngle += numpy.pi * 2
+
+ startAngle = self.startAngle + deltaAngle
+ return _ArcGeometry(
+ self.center,
+ startPoint,
+ self.endPoint,
+ self.radius,
+ self.weight,
+ startAngle,
+ self.endAngle,
+ self._closed,
+ )
+
+ def withEndAngle(self, endAngle):
+ """Return a new geometry based on this object, with a specific end angle
+ """
+ vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)])
+ endPoint = self.center + vector * self.radius
+
+ # Never add more than 180 to maintain coherency
+ deltaAngle = endAngle - self.endAngle
+ if deltaAngle > numpy.pi:
+ deltaAngle -= numpy.pi * 2
+ elif deltaAngle < -numpy.pi:
+ deltaAngle += numpy.pi * 2
+
+ endAngle = self.endAngle + deltaAngle
+ return _ArcGeometry(
+ self.center,
+ self.startPoint,
+ endPoint,
+ self.radius,
+ self.weight,
+ self.startAngle,
+ endAngle,
+ self._closed,
+ )
+
+ def translated(self, dx, dy):
+ """Return the translated geometry by dx, dy"""
+ delta = numpy.array([dx, dy])
+ center = None if self.center is None else self.center + delta
+ startPoint = None if self.startPoint is None else self.startPoint + delta
+ endPoint = None if self.endPoint is None else self.endPoint + delta
+ return _ArcGeometry(center, startPoint, endPoint,
+ self.radius, self.weight,
+ self.startAngle, self.endAngle, self._closed)
+
+ def getKind(self):
+ """Returns the kind of shape defined"""
+ if self.center is None:
+ return "rect"
+ elif numpy.isnan(self.startAngle):
+ return "point"
+ elif self.isClosed():
+ if self.weight <= 0 or self.weight * 0.5 >= self.radius:
+ return "circle"
+ else:
+ return "donut"
+ else:
+ if self.weight * 0.5 < self.radius:
+ return "arc"
+ else:
+ return "camembert"
+
+ def isClosed(self):
+ """Returns True if the geometry is a circle like"""
+ if self._closed is not None:
+ return self._closed
+ delta = numpy.abs(self.endAngle - self.startAngle)
+ self._closed = numpy.isclose(delta, numpy.pi * 2)
+ return self._closed
+
+ def __str__(self):
+ return str((self.center,
+ self.startPoint,
+ self.endPoint,
+ self.radius,
+ self.weight,
+ self.startAngle,
+ self.endAngle,
+ self._closed))
+
+
+class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
+ """A ROI identifying an arc of a circle with a width.
+
+ This ROI provides
+ - 3 handle to control the curvature
+ - 1 handle to control the weight
+ - 1 anchor to translate the shape.
+ """
+
+ ICON = 'add-shape-arc'
+ NAME = 'arc ROI'
+ SHORT_NAME = "arc"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "line"
+ """Plot shape which is used for the first interaction"""
+
+ ThreePointMode = RoiInteractionMode("3 points", "Provides 3 points to define the main radius circle")
+ PolarMode = RoiInteractionMode("Polar", "Provides anchors to edit the ROI in polar coords")
+ # FIXME: MoveMode was designed cause there is too much anchors
+ # FIXME: It would be good replace it by a dnd on the shape
+ MoveMode = RoiInteractionMode("Translation", "Provides anchors to only move the ROI")
+
+ def __init__(self, parent=None):
+ HandleBasedROI.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ InteractionModeMixIn.__init__(self)
+
+ self._geometry = _ArcGeometry.createEmpty()
+ self._handleLabel = self.addLabelHandle()
+
+ self._handleStart = self.addHandle()
+ self._handleMid = self.addHandle()
+ self._handleEnd = self.addHandle()
+ self._handleWeight = self.addHandle()
+ self._handleWeight._setConstraint(self._arcCurvatureMarkerConstraint)
+ self._handleMove = self.addTranslateHandle()
+
+ shape = items.Shape("polygon")
+ shape.setPoints([[0, 0], [0, 0]])
+ shape.setColor(rgba(self.getColor()))
+ shape.setFill(False)
+ shape.setOverlay(True)
+ shape.setLineStyle(self.getLineStyle())
+ shape.setLineWidth(self.getLineWidth())
+ self.__shape = shape
+ self.addItem(shape)
+
+ self._initInteractionMode(self.ThreePointMode)
+ self._interactiveModeUpdated(self.ThreePointMode)
+
+ def availableInteractionModes(self):
+ """Returns the list of available interaction modes
+
+ :rtype: List[RoiInteractionMode]
+ """
+ return [self.ThreePointMode, self.PolarMode, self.MoveMode]
+
+ def _interactiveModeUpdated(self, modeId):
+ """Set the interaction mode.
+
+ :param RoiInteractionMode modeId:
+ """
+ if modeId is self.ThreePointMode:
+ self._handleStart.setSymbol("s")
+ self._handleMid.setSymbol("s")
+ self._handleEnd.setSymbol("s")
+ self._handleWeight.setSymbol("d")
+ self._handleMove.setSymbol("+")
+ elif modeId is self.PolarMode:
+ self._handleStart.setSymbol("o")
+ self._handleMid.setSymbol("o")
+ self._handleEnd.setSymbol("o")
+ self._handleWeight.setSymbol("d")
+ self._handleMove.setSymbol("+")
+ elif modeId is self.MoveMode:
+ self._handleStart.setSymbol("")
+ self._handleMid.setSymbol("+")
+ self._handleEnd.setSymbol("")
+ self._handleWeight.setSymbol("")
+ self._handleMove.setSymbol("+")
+ else:
+ assert False
+ if self._geometry.isClosed():
+ if modeId != self.MoveMode:
+ self._handleStart.setSymbol("x")
+ self._handleEnd.setSymbol("x")
+ self._updateHandles()
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.VISIBLE:
+ self._updateItemProperty(event, self, self.__shape)
+ super(ArcROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(ArcROI, self)._updatedStyle(event, style)
+ self.__shape.setColor(style.getColor())
+ self.__shape.setLineStyle(style.getLineStyle())
+ self.__shape.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ """"Initialize the ROI using the points from the first interaction.
+
+ This interaction is constrained by the plot API and only supports few
+ shapes.
+ """
+ # The first shape is a line
+ point0 = points[0]
+ point1 = points[1]
+
+ # Compute a non collinear point for the curvature
+ center = (point1 + point0) * 0.5
+ normal = point1 - center
+ normal = numpy.array((normal[1], -normal[0]))
+ defaultCurvature = numpy.pi / 5.0
+ weightCoef = 0.20
+ mid = center - normal * defaultCurvature
+ distance = numpy.linalg.norm(point0 - point1)
+ weight = distance * weightCoef
+
+ geometry = self._createGeometryFromControlPoints(point0, mid, point1, weight)
+ self._geometry = geometry
+ self._updateHandles()
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def _updateMidHandle(self):
+ """Keep the same geometry, but update the location of the control
+ points.
+
+ So calling this function do not trigger sigRegionChanged.
+ """
+ geometry = self._geometry
+
+ if geometry.isClosed():
+ start = numpy.array(self._handleStart.getPosition())
+ midPos = geometry.center + geometry.center - start
+ else:
+ if geometry.center is None:
+ midPos = geometry.startPoint * 0.5 + geometry.endPoint * 0.5
+ else:
+ midAngle = geometry.startAngle * 0.5 + geometry.endAngle * 0.5
+ vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)])
+ midPos = geometry.center + geometry.radius * vector
+
+ with utils.blockSignals(self._handleMid):
+ self._handleMid.setPosition(*midPos)
+
+ def _updateWeightHandle(self):
+ geometry = self._geometry
+ if geometry.center is None:
+ # rectangle
+ center = (geometry.startPoint + geometry.endPoint) * 0.5
+ normal = geometry.endPoint - geometry.startPoint
+ normal = numpy.array((normal[1], -normal[0]))
+ distance = numpy.linalg.norm(normal)
+ if distance != 0:
+ normal = normal / distance
+ weightPos = center + normal * geometry.weight * 0.5
+ else:
+ if geometry.isClosed():
+ midAngle = geometry.startAngle + numpy.pi * 0.5
+ elif geometry.center is not None:
+ midAngle = (geometry.startAngle + geometry.endAngle) * 0.5
+ vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)])
+ weightPos = geometry.center + (geometry.radius + geometry.weight * 0.5) * vector
+
+ with utils.blockSignals(self._handleWeight):
+ self._handleWeight.setPosition(*weightPos)
+
+ def _getWeightFromHandle(self, weightPos):
+ geometry = self._geometry
+ if geometry.center is None:
+ # rectangle
+ center = (geometry.startPoint + geometry.endPoint) * 0.5
+ return numpy.linalg.norm(center - weightPos) * 2
+ else:
+ distance = numpy.linalg.norm(geometry.center - weightPos)
+ return abs(distance - geometry.radius) * 2
+
+ def _updateHandles(self):
+ geometry = self._geometry
+ with utils.blockSignals(self._handleStart):
+ self._handleStart.setPosition(*geometry.startPoint)
+ with utils.blockSignals(self._handleEnd):
+ self._handleEnd.setPosition(*geometry.endPoint)
+
+ self._updateMidHandle()
+ self._updateWeightHandle()
+ self._updateShape()
+
+ def _updateCurvature(self, start, mid, end, updateCurveHandles, checkClosed=False, updateStart=False):
+ """Update the curvature using 3 control points in the curve
+
+ :param bool updateCurveHandles: If False curve handles are already at
+ the right location
+ """
+ if checkClosed:
+ closed = self._isCloseInPixel(start, end)
+ else:
+ closed = self._geometry.isClosed()
+ if closed:
+ if updateStart:
+ start = end
+ else:
+ end = start
+
+ if updateCurveHandles:
+ with utils.blockSignals(self._handleStart):
+ self._handleStart.setPosition(*start)
+ with utils.blockSignals(self._handleMid):
+ self._handleMid.setPosition(*mid)
+ with utils.blockSignals(self._handleEnd):
+ self._handleEnd.setPosition(*end)
+
+ weight = self._geometry.weight
+ geometry = self._createGeometryFromControlPoints(start, mid, end, weight, closed=closed)
+ self._geometry = geometry
+
+ self._updateWeightHandle()
+ self._updateShape()
+
+ def _updateCloseInAngle(self, geometry, updateStart):
+ azim = numpy.abs(geometry.endAngle - geometry.startAngle)
+ if numpy.pi < azim < 3 * numpy.pi:
+ closed = self._isCloseInPixel(geometry.startPoint, geometry.endPoint)
+ geometry._closed = closed
+ if closed:
+ sign = 1 if geometry.startAngle < geometry.endAngle else -1
+ if updateStart:
+ geometry.startPoint = geometry.endPoint
+ geometry.startAngle = geometry.endAngle - sign * 2*numpy.pi
+ else:
+ geometry.endPoint = geometry.startPoint
+ geometry.endAngle = geometry.startAngle + sign * 2*numpy.pi
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ modeId = self.getInteractionMode()
+ if handle is self._handleStart:
+ if modeId is self.ThreePointMode:
+ mid = numpy.array(self._handleMid.getPosition())
+ end = numpy.array(self._handleEnd.getPosition())
+ self._updateCurvature(
+ current, mid, end, checkClosed=True, updateStart=True,
+ updateCurveHandles=False
+ )
+ elif modeId is self.PolarMode:
+ v = current - self._geometry.center
+ startAngle = numpy.angle(complex(v[0], v[1]))
+ geometry = self._geometry.withStartAngle(startAngle)
+ self._updateCloseInAngle(geometry, updateStart=True)
+ self._geometry = geometry
+ self._updateHandles()
+ elif handle is self._handleMid:
+ if modeId is self.ThreePointMode:
+ if self._geometry.isClosed():
+ radius = numpy.linalg.norm(self._geometry.center - current)
+ self._geometry = self._geometry.withRadius(radius)
+ self._updateHandles()
+ else:
+ start = numpy.array(self._handleStart.getPosition())
+ end = numpy.array(self._handleEnd.getPosition())
+ self._updateCurvature(start, current, end, updateCurveHandles=False)
+ elif modeId is self.PolarMode:
+ radius = numpy.linalg.norm(self._geometry.center - current)
+ self._geometry = self._geometry.withRadius(radius)
+ self._updateHandles()
+ elif modeId is self.MoveMode:
+ delta = current - previous
+ self.translate(*delta)
+ elif handle is self._handleEnd:
+ if modeId is self.ThreePointMode:
+ start = numpy.array(self._handleStart.getPosition())
+ mid = numpy.array(self._handleMid.getPosition())
+ self._updateCurvature(
+ start, mid, current, checkClosed=True, updateStart=False,
+ updateCurveHandles=False
+ )
+ elif modeId is self.PolarMode:
+ v = current - self._geometry.center
+ endAngle = numpy.angle(complex(v[0], v[1]))
+ geometry = self._geometry.withEndAngle(endAngle)
+ self._updateCloseInAngle(geometry, updateStart=False)
+ self._geometry = geometry
+ self._updateHandles()
+ elif handle is self._handleWeight:
+ weight = self._getWeightFromHandle(current)
+ self._geometry = self._geometry.withWeight(weight)
+ self._updateShape()
+ elif handle is self._handleMove:
+ delta = current - previous
+ self.translate(*delta)
+
+ def _isCloseInPixel(self, point1, point2):
+ manager = self.parent()
+ if manager is None:
+ return False
+ plot = manager.parent()
+ if plot is None:
+ return False
+ point1 = plot.dataToPixel(*point1)
+ if point1 is None:
+ return False
+ point2 = plot.dataToPixel(*point2)
+ if point2 is None:
+ return False
+ return abs(point1[0] - point2[0]) + abs(point1[1] - point2[1]) < 15
+
+ def _normalizeGeometry(self):
+ """Keep the same phisical geometry, but with normalized parameters.
+ """
+ geometry = self._geometry
+ if geometry.weight * 0.5 >= geometry.radius:
+ radius = (geometry.weight * 0.5 + geometry.radius) * 0.5
+ geometry = geometry.withRadius(radius)
+ geometry = geometry.withWeight(radius * 2)
+ self._geometry = geometry
+ return True
+ return False
+
+ def handleDragFinished(self, handle, origin, current):
+ modeId = self.getInteractionMode()
+ if handle in [self._handleStart, self._handleMid, self._handleEnd]:
+ if modeId is self.ThreePointMode:
+ self._normalizeGeometry()
+ self._updateHandles()
+
+ if self._geometry.isClosed():
+ if modeId is self.MoveMode:
+ self._handleStart.setSymbol("")
+ self._handleEnd.setSymbol("")
+ else:
+ self._handleStart.setSymbol("x")
+ self._handleEnd.setSymbol("x")
+ else:
+ if modeId is self.ThreePointMode:
+ self._handleStart.setSymbol("s")
+ self._handleEnd.setSymbol("s")
+ elif modeId is self.PolarMode:
+ self._handleStart.setSymbol("o")
+ self._handleEnd.setSymbol("o")
+ if modeId is self.MoveMode:
+ self._handleStart.setSymbol("")
+ self._handleEnd.setSymbol("")
+
+ def _createGeometryFromControlPoints(self, start, mid, end, weight, closed=None):
+ """Returns the geometry of the object"""
+ if closed or (closed is None and numpy.allclose(start, end)):
+ # Special arc: It's a closed circle
+ center = (start + mid) * 0.5
+ radius = numpy.linalg.norm(start - center)
+ v = start - center
+ startAngle = numpy.angle(complex(v[0], v[1]))
+ endAngle = startAngle + numpy.pi * 2.0
+ return _ArcGeometry.createCircle(
+ center, start, end, radius, weight, startAngle, endAngle
+ )
+
+ elif numpy.linalg.norm(numpy.cross(mid - start, end - start)) < 1e-5:
+ # Degenerated arc, it's a rectangle
+ return _ArcGeometry.createRect(start, end, weight)
+ else:
+ center, radius = self._circleEquation(start, mid, end)
+ v = start - center
+ startAngle = numpy.angle(complex(v[0], v[1]))
+ v = mid - center
+ midAngle = numpy.angle(complex(v[0], v[1]))
+ v = end - center
+ endAngle = numpy.angle(complex(v[0], v[1]))
+
+ # Is it clockwise or anticlockwise
+ relativeMid = (endAngle - midAngle + 2 * numpy.pi) % (2 * numpy.pi)
+ relativeEnd = (endAngle - startAngle + 2 * numpy.pi) % (2 * numpy.pi)
+ if relativeMid < relativeEnd:
+ if endAngle < startAngle:
+ endAngle += 2 * numpy.pi
+ else:
+ if endAngle > startAngle:
+ endAngle -= 2 * numpy.pi
+
+ return _ArcGeometry(center, start, end,
+ radius, weight, startAngle, endAngle)
+
+ def _createShapeFromGeometry(self, geometry):
+ kind = geometry.getKind()
+ if kind == "rect":
+ # It is not an arc
+ # but we can display it as an intermediate shape
+ normal = geometry.endPoint - geometry.startPoint
+ normal = numpy.array((normal[1], -normal[0]))
+ distance = numpy.linalg.norm(normal)
+ if distance != 0:
+ normal /= distance
+ points = numpy.array([
+ geometry.startPoint + normal * geometry.weight * 0.5,
+ geometry.endPoint + normal * geometry.weight * 0.5,
+ geometry.endPoint - normal * geometry.weight * 0.5,
+ geometry.startPoint - normal * geometry.weight * 0.5])
+ elif kind == "point":
+ # It is not an arc
+ # but we can display it as an intermediate shape
+ # NOTE: At least 2 points are expected
+ points = numpy.array([geometry.startPoint, geometry.startPoint])
+ elif kind == "circle":
+ outerRadius = geometry.radius + geometry.weight * 0.5
+ angles = numpy.linspace(0, 2 * numpy.pi, num=50)
+ # It's a circle
+ points = []
+ numpy.append(angles, angles[-1])
+ for angle in angles:
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points.append(geometry.center + direction * outerRadius)
+ points = numpy.array(points)
+ elif kind == "donut":
+ innerRadius = geometry.radius - geometry.weight * 0.5
+ outerRadius = geometry.radius + geometry.weight * 0.5
+ angles = numpy.linspace(0, 2 * numpy.pi, num=50)
+ # It's a donut
+ points = []
+ # NOTE: NaN value allow to create 2 separated circle shapes
+ # using a single plot item. It's a kind of cheat
+ points.append(numpy.array([float("nan"), float("nan")]))
+ for angle in angles:
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points.insert(0, geometry.center + direction * innerRadius)
+ points.append(geometry.center + direction * outerRadius)
+ points.append(numpy.array([float("nan"), float("nan")]))
+ points = numpy.array(points)
+ else:
+ innerRadius = geometry.radius - geometry.weight * 0.5
+ outerRadius = geometry.radius + geometry.weight * 0.5
+
+ delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1
+ if geometry.startAngle == geometry.endAngle:
+ # Degenerated, it's a line (single radius)
+ angle = geometry.startAngle
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points = []
+ points.append(geometry.center + direction * innerRadius)
+ points.append(geometry.center + direction * outerRadius)
+ return numpy.array(points)
+
+ angles = numpy.arange(geometry.startAngle, geometry.endAngle, delta)
+ if angles[-1] != geometry.endAngle:
+ angles = numpy.append(angles, geometry.endAngle)
+
+ if kind == "camembert":
+ # It's a part of camembert
+ points = []
+ points.append(geometry.center)
+ points.append(geometry.startPoint)
+ delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1
+ for angle in angles:
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points.append(geometry.center + direction * outerRadius)
+ points.append(geometry.endPoint)
+ points.append(geometry.center)
+ elif kind == "arc":
+ # It's a part of donut
+ points = []
+ points.append(geometry.startPoint)
+ for angle in angles:
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points.insert(0, geometry.center + direction * innerRadius)
+ points.append(geometry.center + direction * outerRadius)
+ points.insert(0, geometry.endPoint)
+ points.append(geometry.endPoint)
+ else:
+ assert False
+
+ points = numpy.array(points)
+
+ return points
+
+ def _updateShape(self):
+ geometry = self._geometry
+ points = self._createShapeFromGeometry(geometry)
+ self.__shape.setPoints(points)
+
+ index = numpy.nanargmin(points[:, 1])
+ pos = points[index]
+ with utils.blockSignals(self._handleLabel):
+ self._handleLabel.setPosition(pos[0], pos[1])
+
+ if geometry.center is None:
+ movePos = geometry.startPoint * 0.34 + geometry.endPoint * 0.66
+ else:
+ movePos = geometry.center
+
+ with utils.blockSignals(self._handleMove):
+ self._handleMove.setPosition(*movePos)
+
+ self.sigRegionChanged.emit()
+
+ def getGeometry(self):
+ """Returns a tuple containing the geometry of this ROI
+
+ It is a symmetric function of :meth:`setGeometry`.
+
+ If `startAngle` is smaller than `endAngle` the rotation is clockwise,
+ else the rotation is anticlockwise.
+
+ :rtype: Tuple[numpy.ndarray,float,float,float,float]
+ :raise ValueError: In case the ROI can't be represented as section of
+ a circle
+ """
+ geometry = self._geometry
+ if geometry.center is None:
+ raise ValueError("This ROI can't be represented as a section of circle")
+ return geometry.center, self.getInnerRadius(), self.getOuterRadius(), geometry.startAngle, geometry.endAngle
+
+ def isClosed(self):
+ """Returns true if the arc is a closed shape, like a circle or a donut.
+
+ :rtype: bool
+ """
+ return self._geometry.isClosed()
+
+ def getCenter(self):
+ """Returns the center of the circle used to draw arcs of this ROI.
+
+ This center is usually outside the the shape itself.
+
+ :rtype: numpy.ndarray
+ """
+ return self._geometry.center
+
+ def getStartAngle(self):
+ """Returns the angle of the start of the section of this ROI (in radian).
+
+ If `startAngle` is smaller than `endAngle` the rotation is clockwise,
+ else the rotation is anticlockwise.
+
+ :rtype: float
+ """
+ return self._geometry.startAngle
+
+ def getEndAngle(self):
+ """Returns the angle of the end of the section of this ROI (in radian).
+
+ If `startAngle` is smaller than `endAngle` the rotation is clockwise,
+ else the rotation is anticlockwise.
+
+ :rtype: float
+ """
+ return self._geometry.endAngle
+
+ def getInnerRadius(self):
+ """Returns the radius of the smaller arc used to draw this ROI.
+
+ :rtype: float
+ """
+ geometry = self._geometry
+ radius = geometry.radius - geometry.weight * 0.5
+ if radius < 0:
+ radius = 0
+ return radius
+
+ def getOuterRadius(self):
+ """Returns the radius of the bigger arc used to draw this ROI.
+
+ :rtype: float
+ """
+ geometry = self._geometry
+ radius = geometry.radius + geometry.weight * 0.5
+ return radius
+
+ def setGeometry(self, center, innerRadius, outerRadius, startAngle, endAngle):
+ """
+ Set the geometry of this arc.
+
+ :param numpy.ndarray center: Center of the circle.
+ :param float innerRadius: Radius of the smaller arc of the section.
+ :param float outerRadius: Weight of the bigger arc of the section.
+ It have to be bigger than `innerRadius`
+ :param float startAngle: Location of the start of the section (in radian)
+ :param float endAngle: Location of the end of the section (in radian).
+ If `startAngle` is smaller than `endAngle` the rotation is clockwise,
+ else the rotation is anticlockwise.
+ """
+ assert innerRadius <= outerRadius
+ assert numpy.abs(startAngle - endAngle) <= 2 * numpy.pi
+ center = numpy.array(center)
+ radius = (innerRadius + outerRadius) * 0.5
+ weight = outerRadius - innerRadius
+
+ vector = numpy.array([numpy.cos(startAngle), numpy.sin(startAngle)])
+ startPoint = center + vector * radius
+ vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)])
+ endPoint = center + vector * radius
+
+ geometry = _ArcGeometry(center, startPoint, endPoint,
+ radius, weight,
+ startAngle, endAngle, closed=None)
+ self._geometry = geometry
+ self._updateHandles()
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ # first check distance, fastest
+ center = self.getCenter()
+ distance = numpy.sqrt((position[1] - center[1]) ** 2 + ((position[0] - center[0])) ** 2)
+ is_in_distance = self.getInnerRadius() <= distance <= self.getOuterRadius()
+ if not is_in_distance:
+ return False
+ rel_pos = position[1] - center[1], position[0] - center[0]
+ angle = numpy.arctan2(*rel_pos)
+ # angle is inside [-pi, pi]
+
+ # Normalize the start angle between [-pi, pi]
+ # with a positive angle range
+ start_angle = self.getStartAngle()
+ end_angle = self.getEndAngle()
+ azim_range = end_angle - start_angle
+ if azim_range < 0:
+ start_angle = end_angle
+ azim_range = -azim_range
+ start_angle = numpy.mod(start_angle + numpy.pi, 2 * numpy.pi) - numpy.pi
+
+ if angle < start_angle:
+ angle += 2 * numpy.pi
+ return start_angle <= angle <= start_angle + azim_range
+
+ def translate(self, x, y):
+ self._geometry = self._geometry.translated(x, y)
+ self._updateHandles()
+
+ def _arcCurvatureMarkerConstraint(self, x, y):
+ """Curvature marker remains on perpendicular bisector"""
+ geometry = self._geometry
+ if geometry.center is None:
+ center = (geometry.startPoint + geometry.endPoint) * 0.5
+ vector = geometry.startPoint - geometry.endPoint
+ vector = numpy.array((vector[1], -vector[0]))
+ vdist = numpy.linalg.norm(vector)
+ if vdist != 0:
+ normal = numpy.array((vector[1], -vector[0])) / vdist
+ else:
+ normal = numpy.array((0, 0))
+ else:
+ if geometry.isClosed():
+ midAngle = geometry.startAngle + numpy.pi * 0.5
+ else:
+ midAngle = (geometry.startAngle + geometry.endAngle) * 0.5
+ normal = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)])
+ center = geometry.center
+ dist = numpy.dot(normal, (numpy.array((x, y)) - center))
+ dist = numpy.clip(dist, geometry.radius, geometry.radius * 2)
+ x, y = center + dist * normal
+ return x, y
+
+ @staticmethod
+ def _circleEquation(pt1, pt2, pt3):
+ """Circle equation from 3 (x, y) points
+
+ :return: Position of the center of the circle and the radius
+ :rtype: Tuple[Tuple[float,float],float]
+ """
+ x, y, z = complex(*pt1), complex(*pt2), complex(*pt3)
+ w = z - x
+ w /= y - x
+ c = (x - y) * (w - abs(w) ** 2) / 2j / w.imag - x
+ return numpy.array((-c.real, -c.imag)), abs(c + x)
+
+ def __str__(self):
+ try:
+ center, innerRadius, outerRadius, startAngle, endAngle = self.getGeometry()
+ params = center[0], center[1], innerRadius, outerRadius, startAngle, endAngle
+ params = 'center: %f %f; radius: %f %f; angles: %f %f' % params
+ except ValueError:
+ params = "invalid"
+ return "%s(%s)" % (self.__class__.__name__, params)
diff --git a/silx/gui/plot/items/_pick.py b/silx/gui/plot/items/_pick.py
index 4ddf4f6..8c8e781 100644
--- a/silx/gui/plot/items/_pick.py
+++ b/silx/gui/plot/items/_pick.py
@@ -48,7 +48,7 @@ class PickingResult(object):
self._indices = None
else:
# Indices is set to None if indices array is empty
- indices = numpy.array(indices, copy=False, dtype=numpy.int)
+ indices = numpy.array(indices, copy=False, dtype=numpy.int64)
self._indices = None if indices.size == 0 else indices
def getItem(self):
diff --git a/silx/gui/plot/items/_roi_base.py b/silx/gui/plot/items/_roi_base.py
new file mode 100644
index 0000000..3eb6cf4
--- /dev/null
+++ b/silx/gui/plot/items/_roi_base.py
@@ -0,0 +1,835 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides base components to create ROI item for
+the :class:`~silx.gui.plot.PlotWidget`.
+
+.. inheritance-diagram::
+ silx.gui.plot.items.roi
+ :parts: 1
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import logging
+import numpy
+import weakref
+
+from ....utils.weakref import WeakList
+from ... import qt
+from .. import items
+from ..items import core
+from ...colors import rgba
+import silx.utils.deprecation
+from ....utils.proxy import docstring
+
+
+logger = logging.getLogger(__name__)
+
+
+class _RegionOfInterestBase(qt.QObject):
+ """Base class of 1D and 2D region of interest
+
+ :param QObject parent: See QObject
+ :param str name: The name of the ROI
+ """
+
+ sigAboutToBeRemoved = qt.Signal()
+ """Signal emitted just before this ROI is removed from its manager."""
+
+ sigItemChanged = qt.Signal(object)
+ """Signal emitted when item has changed.
+
+ It provides a flag describing which property of the item has changed.
+ See :class:`ItemChangedType` for flags description.
+ """
+
+ def __init__(self, parent=None):
+ qt.QObject.__init__(self, parent=parent)
+ self.__name = ''
+
+ def getName(self):
+ """Returns the name of the ROI
+
+ :return: name of the region of interest
+ :rtype: str
+ """
+ return self.__name
+
+ def setName(self, name):
+ """Set the name of the ROI
+
+ :param str name: name of the region of interest
+ """
+ name = str(name)
+ if self.__name != name:
+ self.__name = name
+ self._updated(items.ItemChangedType.NAME)
+
+ def _updated(self, event=None, checkVisibility=True):
+ """Implement Item mix-in update method by updating the plot items
+
+ See :class:`~silx.gui.plot.items.Item._updated`
+ """
+ self.sigItemChanged.emit(event)
+
+ def contains(self, position):
+ """Returns True if the `position` is in this ROI.
+
+ :param tuple[float,float] position: position to check
+ :return: True if the value / point is consider to be in the region of
+ interest.
+ :rtype: bool
+ """
+ return False # Override in subclass to perform actual test
+
+
+class RoiInteractionMode(object):
+ """Description of an interaction mode.
+
+ An interaction mode provide a specific kind of interaction for a ROI.
+ A ROI can implement many interaction.
+ """
+
+ def __init__(self, label, description=None):
+ self._label = label
+ self._description = description
+
+ @property
+ def label(self):
+ return self._label
+
+ @property
+ def description(self):
+ return self._description
+
+
+class InteractionModeMixIn(object):
+ """Mix in feature which can be implemented by a ROI object.
+
+ This provides user interaction to switch between different
+ interaction mode to edit the ROI.
+
+ This ROI modes have to be described using `RoiInteractionMode`,
+ and taken into account during interation with handles.
+ """
+
+ sigInteractionModeChanged = qt.Signal(object)
+
+ def __init__(self):
+ self.__modeId = None
+
+ def _initInteractionMode(self, modeId):
+ """Set the mode without updating anything.
+
+ Must be one of the returned :meth:`availableInteractionModes`.
+
+ :param RoiInteractionMode modeId: Mode to use
+ """
+ self.__modeId = modeId
+
+ def availableInteractionModes(self):
+ """Returns the list of available interaction modes
+
+ Must be implemented when inherited to provide all available modes.
+
+ :rtype: List[RoiInteractionMode]
+ """
+ raise NotImplementedError()
+
+ def setInteractionMode(self, modeId):
+ """Set the interaction mode.
+
+ :param RoiInteractionMode modeId: Mode to use
+ """
+ self.__modeId = modeId
+ self._interactiveModeUpdated(modeId)
+ self.sigInteractionModeChanged.emit(modeId)
+
+ def _interactiveModeUpdated(self, modeId):
+ """Called directly after an update of the mode.
+
+ The signal `sigInteractionModeChanged` is triggered after this
+ call.
+
+ Must be implemented when inherited to take care of the change.
+ """
+ raise NotImplementedError()
+
+ def getInteractionMode(self):
+ """Returns the interaction mode.
+
+ Must be one of the returned :meth:`availableInteractionModes`.
+
+ :rtype: RoiInteractionMode
+ """
+ return self.__modeId
+
+
+class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
+ """Object describing a region of interest in a plot.
+
+ :param QObject parent:
+ The RegionOfInterestManager that created this object
+ """
+
+ _DEFAULT_LINEWIDTH = 1.
+ """Default line width of the curve"""
+
+ _DEFAULT_LINESTYLE = '-'
+ """Default line style of the curve"""
+
+ _DEFAULT_HIGHLIGHT_STYLE = items.CurveStyle(linewidth=2)
+ """Default highlight style of the item"""
+
+ ICON, NAME, SHORT_NAME = None, None, None
+ """Metadata to describe the ROI in labels, tooltips and widgets
+
+ Should be set by inherited classes to custom the ROI manager widget.
+ """
+
+ sigRegionChanged = qt.Signal()
+ """Signal emitted everytime the shape or position of the ROI changes"""
+
+ sigEditingStarted = qt.Signal()
+ """Signal emitted when the user start editing the roi"""
+
+ sigEditingFinished = qt.Signal()
+ """Signal emitted when the region edition is finished. During edition
+ sigEditionChanged will be emitted several times and
+ sigRegionEditionFinished only at end"""
+
+ def __init__(self, parent=None):
+ # Avoid circular dependency
+ from ..tools import roi as roi_tools
+ assert parent is None or isinstance(parent, roi_tools.RegionOfInterestManager)
+ _RegionOfInterestBase.__init__(self, parent)
+ core.HighlightedMixIn.__init__(self)
+ self._color = rgba('red')
+ self._editable = False
+ self._selectable = False
+ self._focusProxy = None
+ self._visible = True
+ self._child = WeakList()
+
+ def _connectToPlot(self, plot):
+ """Called after connection to a plot"""
+ for item in self.getItems():
+ # This hack is needed to avoid reentrant call from _disconnectFromPlot
+ # to the ROI manager. It also speed up the item tests in _itemRemoved
+ item._roiGroup = True
+ plot.addItem(item)
+
+ def _disconnectFromPlot(self, plot):
+ """Called before disconnection from a plot"""
+ for item in self.getItems():
+ # The item could be already be removed by the plot
+ if item.getPlot() is not None:
+ del item._roiGroup
+ plot.removeItem(item)
+
+ def _setItemName(self, item):
+ """Helper to generate a unique id to a plot item"""
+ legend = "__ROI-%d__%d" % (id(self), id(item))
+ item.setName(legend)
+
+ def setParent(self, parent):
+ """Set the parent of the RegionOfInterest
+
+ :param Union[None,RegionOfInterestManager] parent: The new parent
+ """
+ # Avoid circular dependency
+ from ..tools import roi as roi_tools
+ if (parent is not None and not isinstance(parent, roi_tools.RegionOfInterestManager)):
+ raise ValueError('Unsupported parent')
+
+ previousParent = self.parent()
+ if previousParent is not None:
+ previousPlot = previousParent.parent()
+ if previousPlot is not None:
+ self._disconnectFromPlot(previousPlot)
+ super(RegionOfInterest, self).setParent(parent)
+ if parent is not None:
+ plot = parent.parent()
+ if plot is not None:
+ self._connectToPlot(plot)
+
+ def addItem(self, item):
+ """Add an item to the set of this ROI children.
+
+ This item will be added and removed to the plot used by the ROI.
+
+ If the ROI is already part of a plot, the item will also be added to
+ the plot.
+
+ It the item do not have a name already, a unique one is generated to
+ avoid item collision in the plot.
+
+ :param silx.gui.plot.items.Item item: A plot item
+ """
+ assert item is not None
+ self._child.append(item)
+ if item.getName() == '':
+ self._setItemName(item)
+ manager = self.parent()
+ if manager is not None:
+ plot = manager.parent()
+ if plot is not None:
+ item._roiGroup = True
+ plot.addItem(item)
+
+ def removeItem(self, item):
+ """Remove an item from this ROI children.
+
+ If the item is part of a plot it will be removed too.
+
+ :param silx.gui.plot.items.Item item: A plot item
+ """
+ assert item is not None
+ self._child.remove(item)
+ plot = item.getPlot()
+ if plot is not None:
+ del item._roiGroup
+ plot.removeItem(item)
+
+ def getItems(self):
+ """Returns the list of PlotWidget items of this RegionOfInterest.
+
+ :rtype: List[~silx.gui.plot.items.Item]
+ """
+ return tuple(self._child)
+
+ @classmethod
+ def _getShortName(cls):
+ """Return an human readable kind of ROI
+
+ :rtype: str
+ """
+ if hasattr(cls, "SHORT_NAME"):
+ name = cls.SHORT_NAME
+ if name is None:
+ name = cls.__name__
+ return name
+
+ def getColor(self):
+ """Returns the color of this ROI
+
+ :rtype: QColor
+ """
+ return qt.QColor.fromRgbF(*self._color)
+
+ def setColor(self, color):
+ """Set the color used for this ROI.
+
+ :param color: The color to use for ROI shape as
+ either a color name, a QColor, a list of uint8 or float in [0, 1].
+ """
+ color = rgba(color)
+ if color != self._color:
+ self._color = color
+ self._updated(items.ItemChangedType.COLOR)
+
+ @silx.utils.deprecation.deprecated(reason='API modification',
+ replacement='getName()',
+ since_version=0.12)
+ def getLabel(self):
+ """Returns the label displayed for this ROI.
+
+ :rtype: str
+ """
+ return self.getName()
+
+ @silx.utils.deprecation.deprecated(reason='API modification',
+ replacement='setName(name)',
+ since_version=0.12)
+ def setLabel(self, label):
+ """Set the label displayed with this ROI.
+
+ :param str label: The text label to display
+ """
+ self.setName(name=label)
+
+ def isEditable(self):
+ """Returns whether the ROI is editable by the user or not.
+
+ :rtype: bool
+ """
+ return self._editable
+
+ def setEditable(self, editable):
+ """Set whether the ROI can be changed interactively.
+
+ :param bool editable: True to allow edition by the user,
+ False to disable.
+ """
+ editable = bool(editable)
+ if self._editable != editable:
+ self._editable = editable
+ self._updated(items.ItemChangedType.EDITABLE)
+
+ def isSelectable(self):
+ """Returns whether the ROI is selectable by the user or not.
+
+ :rtype: bool
+ """
+ return self._selectable
+
+ def setSelectable(self, selectable):
+ """Set whether the ROI can be selected interactively.
+
+ :param bool selectable: True to allow selection by the user,
+ False to disable.
+ """
+ selectable = bool(selectable)
+ if self._selectable != selectable:
+ self._selectable = selectable
+ self._updated(items.ItemChangedType.SELECTABLE)
+
+ def getFocusProxy(self):
+ """Returns the ROI which have to be selected when this ROI is selected,
+ else None if no proxy specified.
+
+ :rtype: RegionOfInterest
+ """
+ proxy = self._focusProxy
+ if proxy is None:
+ return None
+ proxy = proxy()
+ if proxy is None:
+ self._focusProxy = None
+ return proxy
+
+ def setFocusProxy(self, roi):
+ """Set the real ROI which will be selected when this ROI is selected,
+ else None to remove the proxy already specified.
+
+ :param RegionOfInterest roi: A ROI
+ """
+ if roi is not None:
+ self._focusProxy = weakref.ref(roi)
+ else:
+ self._focusProxy = None
+
+ def isVisible(self):
+ """Returns whether the ROI is visible in the plot.
+
+ .. note::
+ This does not take into account whether or not the plot
+ widget itself is visible (unlike :meth:`QWidget.isVisible` which
+ checks the visibility of all its parent widgets up to the window)
+
+ :rtype: bool
+ """
+ return self._visible
+
+ def setVisible(self, visible):
+ """Set whether the plot items associated with this ROI are
+ visible in the plot.
+
+ :param bool visible: True to show the ROI in the plot, False to
+ hide it.
+ """
+ visible = bool(visible)
+ if self._visible != visible:
+ self._visible = visible
+ self._updated(items.ItemChangedType.VISIBLE)
+
+ @classmethod
+ def showFirstInteractionShape(cls):
+ """Returns True if the shape created by the first interaction and
+ managed by the plot have to be visible.
+
+ :rtype: bool
+ """
+ return False
+
+ @classmethod
+ def getFirstInteractionShape(cls):
+ """Returns the shape kind which will be used by the very first
+ interaction with the plot.
+
+ This interactions are hardcoded inside the plot
+
+ :rtype: str
+ """
+ return cls._plotShape
+
+ def setFirstShapePoints(self, points):
+ """"Initialize the ROI using the points from the first interaction.
+
+ This interaction is constrained by the plot API and only supports few
+ shapes.
+ """
+ raise NotImplementedError()
+
+ def creationStarted(self):
+ """"Called when the ROI creation interaction was started.
+ """
+ pass
+
+ def creationFinalized(self):
+ """"Called when the ROI creation interaction was finalized.
+ """
+ pass
+
+ def _updateItemProperty(self, event, source, destination):
+ """Update the item property of a destination from an item source.
+
+ :param items.ItemChangedType event: Property type to update
+ :param silx.gui.plot.items.Item source: The reference for the data
+ :param event Union[Item,List[Item]] destination: The item(s) to update
+ """
+ if not isinstance(destination, (list, tuple)):
+ destination = [destination]
+ if event == items.ItemChangedType.NAME:
+ value = source.getName()
+ for d in destination:
+ d.setName(value)
+ elif event == items.ItemChangedType.EDITABLE:
+ value = source.isEditable()
+ for d in destination:
+ d.setEditable(value)
+ elif event == items.ItemChangedType.SELECTABLE:
+ value = source.isSelectable()
+ for d in destination:
+ d._setSelectable(value)
+ elif event == items.ItemChangedType.COLOR:
+ value = rgba(source.getColor())
+ for d in destination:
+ d.setColor(value)
+ elif event == items.ItemChangedType.LINE_STYLE:
+ value = self.getLineStyle()
+ for d in destination:
+ d.setLineStyle(value)
+ elif event == items.ItemChangedType.LINE_WIDTH:
+ value = self.getLineWidth()
+ for d in destination:
+ d.setLineWidth(value)
+ elif event == items.ItemChangedType.SYMBOL:
+ value = self.getSymbol()
+ for d in destination:
+ d.setSymbol(value)
+ elif event == items.ItemChangedType.SYMBOL_SIZE:
+ value = self.getSymbolSize()
+ for d in destination:
+ d.setSymbolSize(value)
+ elif event == items.ItemChangedType.VISIBLE:
+ value = self.isVisible()
+ for d in destination:
+ d.setVisible(value)
+ else:
+ assert False
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.HIGHLIGHTED:
+ style = self.getCurrentStyle()
+ self._updatedStyle(event, style)
+ else:
+ styleEvents = [items.ItemChangedType.COLOR,
+ items.ItemChangedType.LINE_STYLE,
+ items.ItemChangedType.LINE_WIDTH,
+ items.ItemChangedType.SYMBOL,
+ items.ItemChangedType.SYMBOL_SIZE]
+ if self.isHighlighted():
+ styleEvents.append(items.ItemChangedType.HIGHLIGHTED_STYLE)
+
+ if event in styleEvents:
+ style = self.getCurrentStyle()
+ self._updatedStyle(event, style)
+
+ super(RegionOfInterest, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ """Called when the current displayed style of the ROI was changed.
+
+ :param event: The event responsible of the change of the style
+ :param items.CurveStyle style: The current style
+ """
+ pass
+
+ def getCurrentStyle(self):
+ """Returns the current curve style.
+
+ Curve style depends on curve highlighting
+
+ :rtype: CurveStyle
+ """
+ baseColor = rgba(self.getColor())
+ if isinstance(self, core.LineMixIn):
+ baseLinestyle = self.getLineStyle()
+ baseLinewidth = self.getLineWidth()
+ else:
+ baseLinestyle = self._DEFAULT_LINESTYLE
+ baseLinewidth = self._DEFAULT_LINEWIDTH
+ if isinstance(self, core.SymbolMixIn):
+ baseSymbol = self.getSymbol()
+ baseSymbolsize = self.getSymbolSize()
+ else:
+ baseSymbol = 'o'
+ baseSymbolsize = 1
+
+ if self.isHighlighted():
+ style = self.getHighlightedStyle()
+ color = style.getColor()
+ linestyle = style.getLineStyle()
+ linewidth = style.getLineWidth()
+ symbol = style.getSymbol()
+ symbolsize = style.getSymbolSize()
+
+ return items.CurveStyle(
+ color=baseColor if color is None else color,
+ linestyle=baseLinestyle if linestyle is None else linestyle,
+ linewidth=baseLinewidth if linewidth is None else linewidth,
+ symbol=baseSymbol if symbol is None else symbol,
+ symbolsize=baseSymbolsize if symbolsize is None else symbolsize)
+ else:
+ return items.CurveStyle(color=baseColor,
+ linestyle=baseLinestyle,
+ linewidth=baseLinewidth,
+ symbol=baseSymbol,
+ symbolsize=baseSymbolsize)
+
+ def _editingStarted(self):
+ assert self._editable is True
+ self.sigEditingStarted.emit()
+
+ def _editingFinished(self):
+ self.sigEditingFinished.emit()
+
+
+class HandleBasedROI(RegionOfInterest):
+ """Manage a ROI based on a set of handles"""
+
+ def __init__(self, parent=None):
+ RegionOfInterest.__init__(self, parent=parent)
+ self._handles = []
+ self._posOrigin = None
+ self._posPrevious = None
+
+ def addUserHandle(self, item=None):
+ """
+ Add a new free handle to the ROI.
+
+ This handle do nothing. It have to be managed by the ROI
+ implementing this class.
+
+ :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
+ add, else None to create a default marker.
+ :rtype: silx.gui.plot.items.Marker
+ """
+ return self.addHandle(item, role="user")
+
+ def addLabelHandle(self, item=None):
+ """
+ Add a new label handle to the ROI.
+
+ This handle is not draggable nor selectable.
+
+ It is displayed without symbol, but it is always visible anyway
+ the ROI is editable, in order to display text.
+
+ :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
+ add, else None to create a default marker.
+ :rtype: silx.gui.plot.items.Marker
+ """
+ return self.addHandle(item, role="label")
+
+ def addTranslateHandle(self, item=None):
+ """
+ Add a new translate handle to the ROI.
+
+ Dragging translate handles affect the position position of the ROI
+ but not the shape itself.
+
+ :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
+ add, else None to create a default marker.
+ :rtype: silx.gui.plot.items.Marker
+ """
+ return self.addHandle(item, role="translate")
+
+ def addHandle(self, item=None, role="default"):
+ """
+ Add a new handle to the ROI.
+
+ Dragging handles while affect the position or the shape of the
+ ROI.
+
+ :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
+ add, else None to create a default marker.
+ :rtype: silx.gui.plot.items.Marker
+ """
+ if item is None:
+ item = items.Marker()
+ color = rgba(self.getColor())
+ color = self._computeHandleColor(color)
+ item.setColor(color)
+ if role == "default":
+ item.setSymbol("s")
+ elif role == "user":
+ pass
+ elif role == "translate":
+ item.setSymbol("+")
+ elif role == "label":
+ item.setSymbol("")
+
+ if role == "user":
+ pass
+ elif role == "label":
+ item._setSelectable(False)
+ item._setDraggable(False)
+ item.setVisible(True)
+ else:
+ self.__updateEditable(item, self.isEditable(), remove=False)
+ item._setSelectable(False)
+
+ self._handles.append((item, role))
+ self.addItem(item)
+ return item
+
+ def removeHandle(self, handle):
+ data = [d for d in self._handles if d[0] is handle][0]
+ self._handles.remove(data)
+ role = data[1]
+ if role not in ["user", "label"]:
+ if self.isEditable():
+ self.__updateEditable(handle, False)
+ self.removeItem(handle)
+
+ def getHandles(self):
+ """Returns the list of handles of this HandleBasedROI.
+
+ :rtype: List[~silx.gui.plot.items.Marker]
+ """
+ return tuple(data[0] for data in self._handles)
+
+ def _updated(self, event=None, checkVisibility=True):
+ """Implement Item mix-in update method by updating the plot items
+
+ See :class:`~silx.gui.plot.items.Item._updated`
+ """
+ if event == items.ItemChangedType.NAME:
+ self._updateText(self.getName())
+ elif event == items.ItemChangedType.VISIBLE:
+ for item, role in self._handles:
+ visible = self.isVisible()
+ editionVisible = visible and self.isEditable()
+ if role not in ["user", "label"]:
+ item.setVisible(editionVisible)
+ else:
+ item.setVisible(visible)
+ elif event == items.ItemChangedType.EDITABLE:
+ for item, role in self._handles:
+ editable = self.isEditable()
+ if role not in ["user", "label"]:
+ self.__updateEditable(item, editable)
+ super(HandleBasedROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(HandleBasedROI, self)._updatedStyle(event, style)
+
+ # Update color of shape items in the plot
+ color = rgba(self.getColor())
+ handleColor = self._computeHandleColor(color)
+ for item, role in self._handles:
+ if role == 'user':
+ pass
+ elif role == 'label':
+ item.setColor(color)
+ else:
+ item.setColor(handleColor)
+
+ def __updateEditable(self, handle, editable, remove=True):
+ # NOTE: visibility change emit a position update event
+ handle.setVisible(editable and self.isVisible())
+ handle._setDraggable(editable)
+ if editable:
+ handle.sigDragStarted.connect(self._handleEditingStarted)
+ handle.sigItemChanged.connect(self._handleEditingUpdated)
+ handle.sigDragFinished.connect(self._handleEditingFinished)
+ else:
+ if remove:
+ handle.sigDragStarted.disconnect(self._handleEditingStarted)
+ handle.sigItemChanged.disconnect(self._handleEditingUpdated)
+ handle.sigDragFinished.disconnect(self._handleEditingFinished)
+
+ def _handleEditingStarted(self):
+ super(HandleBasedROI, self)._editingStarted()
+ handle = self.sender()
+ self._posOrigin = numpy.array(handle.getPosition())
+ self._posPrevious = numpy.array(self._posOrigin)
+ self.handleDragStarted(handle, self._posOrigin)
+
+ def _handleEditingUpdated(self):
+ if self._posOrigin is None:
+ # Avoid to handle events when visibility change
+ return
+ handle = self.sender()
+ current = numpy.array(handle.getPosition())
+ self.handleDragUpdated(handle, self._posOrigin, self._posPrevious, current)
+ self._posPrevious = current
+
+ def _handleEditingFinished(self):
+ handle = self.sender()
+ current = numpy.array(handle.getPosition())
+ self.handleDragFinished(handle, self._posOrigin, current)
+ self._posPrevious = None
+ self._posOrigin = None
+ super(HandleBasedROI, self)._editingFinished()
+
+ def isHandleBeingDragged(self):
+ """Returns True if one of the handles is currently being dragged.
+
+ :rtype: bool
+ """
+ return self._posOrigin is not None
+
+ def handleDragStarted(self, handle, origin):
+ """Called when an handler drag started"""
+ pass
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ """Called when an handle drag position changed"""
+ pass
+
+ def handleDragFinished(self, handle, origin, current):
+ """Called when an handle drag finished"""
+ pass
+
+ def _computeHandleColor(self, color):
+ """Returns the anchor color from the base ROI color
+
+ :param Union[numpy.array,Tuple,List]: color
+ :rtype: Union[numpy.array,Tuple,List]
+ """
+ return color[:3] + (0.5,)
+
+ def _updateText(self, text):
+ """Update the text displayed by this ROI
+
+ :param str text: A text
+ """
+ pass
diff --git a/silx/gui/plot/items/complex.py b/silx/gui/plot/items/complex.py
index 8f0694d..0e492a0 100644
--- a/silx/gui/plot/items/complex.py
+++ b/silx/gui/plot/items/complex.py
@@ -124,10 +124,9 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
"""Overrides supported ComplexMode"""
def __init__(self):
- ImageBase.__init__(self)
+ ImageBase.__init__(self, numpy.zeros((0, 0), dtype=numpy.complex64))
ColormapMixIn.__init__(self)
ComplexMixIn.__init__(self)
- self._data = numpy.zeros((0, 0), dtype=numpy.complex64)
self._dataByModesCache = {}
self._amplitudeRangeInfo = None, 2
@@ -264,17 +263,9 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
'Image is not complex, converting it to complex to plot it.')
data = numpy.array(data, dtype=numpy.complex64)
- self._data = data
self._dataByModesCache = {}
self._setColormappedData(self.getData(copy=False), copy=False)
-
- # TODO hackish data range implementation
- if self.isVisible():
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
-
- self._updated(ItemChangedType.DATA)
+ super().setData(data)
def getComplexData(self, copy=True):
"""Returns the image complex data
@@ -283,7 +274,7 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
False to use internal representation (do not modify!)
:rtype: numpy.ndarray of complex
"""
- return numpy.array(self._data, copy=copy)
+ return super().getData(copy=copy)
def getData(self, copy=True, mode=None):
"""Returns the image data corresponding to (current) mode.
diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py
index 9426a13..edc6d89 100644
--- a/silx/gui/plot/items/core.py
+++ b/silx/gui/plot/items/core.py
@@ -37,6 +37,7 @@ except ImportError: # Python2 support
from copy import deepcopy
import logging
import enum
+from typing import Optional, Tuple
import warnings
import weakref
@@ -44,7 +45,9 @@ import numpy
import six
from ....utils.deprecation import deprecated
+from ....utils.proxy import docstring
from ....utils.enum import Enum as _Enum
+from ....math.combo import min_max
from ... import qt
from ... import colors
from ...colors import Colormap
@@ -164,6 +167,13 @@ class Item(qt.QObject):
See :class:`ItemChangedType` for flags description.
"""
+ _sigVisibleBoundsChanged = qt.Signal()
+ """Signal emitted when the visible extent of the item in the plot has changed.
+
+ This signal is emitted only if visible extent tracking is enabled
+ (see :meth:`_setVisibleBoundsTracking`).
+ """
+
def __init__(self):
qt.QObject.__init__(self)
self._dirty = True
@@ -176,6 +186,9 @@ class Item(qt.QObject):
self._ylabel = None
self.__name = ''
+ self.__visibleBoundsTracking = False
+ self.__previousVisibleBounds = None
+
self._backendRenderer = None
def getPlot(self):
@@ -194,7 +207,9 @@ class Item(qt.QObject):
"""
if plot is not None and self._plotRef is not None:
raise RuntimeError('Trying to add a node at two places.')
+ self.__disconnectFromPlotWidget()
self._plotRef = None if plot is None else weakref.ref(plot)
+ self.__connectToPlotWidget()
self._updated()
def getBounds(self): # TODO return a Bounds object rather than a tuple
@@ -300,6 +315,97 @@ class Item(qt.QObject):
info = deepcopy(info)
self._info = info
+ def getVisibleBounds(self) -> Optional[Tuple[float,float,float,float]]:
+ """Returns visible bounds of the item bounding box in the plot area.
+
+ :returns:
+ (xmin, xmax, ymin, ymax) in data coordinates of the visible area or
+ None if item is not visible in the plot area.
+ :rtype: Union[List[float],None]
+ """
+ plot = self.getPlot()
+ bounds = self.getBounds()
+ if plot is None or bounds is None or not self.isVisible():
+ return None
+
+ xmin, xmax = numpy.clip(bounds[:2], *plot.getXAxis().getLimits())
+ ymin, ymax = numpy.clip(
+ bounds[2:], *plot.getYAxis(self.__getYAxis()).getLimits())
+
+ if xmin == xmax or ymin == ymax: # Outside the plot area
+ return None
+ else:
+ return xmin, xmax, ymin, ymax
+
+ def _isVisibleBoundsTracking(self) -> bool:
+ """Returns True if visible bounds changes are tracked.
+
+ When enabled, :attr:`_sigVisibleBoundsChanged` is emitted upon changes.
+ :rtype: bool
+ """
+ return self.__visibleBoundsTracking
+
+ def _setVisibleBoundsTracking(self, enable: bool) -> None:
+ """Set whether or not to track visible bounds changes.
+
+ :param bool enable:
+ """
+ if enable != self.__visibleBoundsTracking:
+ self.__disconnectFromPlotWidget()
+ self.__previousVisibleBounds = None
+ self.__visibleBoundsTracking = enable
+ self.__connectToPlotWidget()
+
+ def __getYAxis(self) -> str:
+ """Returns current Y axis ('left' or 'right')"""
+ return self.getYAxis() if isinstance(self, YAxisMixIn) else 'left'
+
+ def __connectToPlotWidget(self) -> None:
+ """Connect to PlotWidget signals and install event filter"""
+ if not self._isVisibleBoundsTracking():
+ return
+
+ plot = self.getPlot()
+ if plot is not None:
+ for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())):
+ axis.sigLimitsChanged.connect(self._visibleBoundsChanged)
+
+ plot.installEventFilter(self)
+
+ self._visibleBoundsChanged()
+
+ def __disconnectFromPlotWidget(self) -> None:
+ """Disconnect from PlotWidget signals and remove event filter"""
+ if not self._isVisibleBoundsTracking():
+ return
+
+ plot = self.getPlot()
+ if plot is not None:
+ for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())):
+ axis.sigLimitsChanged.disconnect(self._visibleBoundsChanged)
+
+ plot.removeEventFilter(self)
+
+ def _visibleBoundsChanged(self, *args) -> None:
+ """Check if visible extent actually changed and emit signal"""
+ if not self._isVisibleBoundsTracking():
+ return # No visible extent tracking
+
+ plot = self.getPlot()
+ if plot is None or not plot.isVisible():
+ return # No plot or plot not visible
+
+ extent = self.getVisibleBounds()
+ if extent != self.__previousVisibleBounds:
+ self.__previousVisibleBounds = extent
+ self._sigVisibleBoundsChanged.emit()
+
+ def eventFilter(self, watched, event):
+ """Event filter to handle PlotWidget show events"""
+ if watched is self.getPlot() and event.type() == qt.QEvent.Show:
+ self._visibleBoundsChanged()
+ return super().eventFilter(watched, event)
+
def _updated(self, event=None, checkVisibility=True):
"""Mark the item as dirty (i.e., needing update).
@@ -375,6 +481,29 @@ class Item(qt.QObject):
return PickingResult(self, indices)
+class DataItem(Item):
+ """Item with a data extent in the plot"""
+
+ def _boundsChanged(self, checkVisibility: bool=True) -> None:
+ """Call this method in subclass when data bounds has changed.
+
+ :param bool checkVisibility:
+ """
+ if not checkVisibility or self.isVisible():
+ self._visibleBoundsChanged()
+
+ # TODO hackish data range implementation
+ plot = self.getPlot()
+ if plot is not None:
+ plot._invalidateDataRange()
+
+ @docstring(Item)
+ def setVisible(self, visible: bool):
+ if visible != self.isVisible():
+ self._boundsChanged(checkVisibility=False)
+ super().setVisible(visible)
+
+
# Mix-in classes ##############################################################
class ItemMixInBase(object):
@@ -836,6 +965,22 @@ class YAxisMixIn(ItemMixInBase):
assert yaxis in ('left', 'right')
if yaxis != self._yaxis:
self._yaxis = yaxis
+ # Handle data extent changed for DataItem
+ if isinstance(self, DataItem):
+ self._boundsChanged()
+
+ # Handle visible extent changed
+ if self._isVisibleBoundsTracking():
+ # Switch Y axis signal connection
+ plot = self.getPlot()
+ if plot is not None:
+ previousYAxis = 'left' if self.getXAxis() == 'right' else 'right'
+ plot.getYAxis(previousYAxis).sigLimitsChanged.disconnect(
+ self._visibleBoundsChanged)
+ plot.getYAxis(self.getYAxis()).sigLimitsChanged.connect(
+ self._visibleBoundsChanged)
+ self._visibleBoundsChanged()
+
self._updated(ItemChangedType.YAXIS)
@@ -1066,6 +1211,16 @@ class ScatterVisualizationMixIn(ItemMixInBase):
Available reduction functions are: 'mean' (default), 'count', 'sum'.
"""
+ DATA_BOUNDS_HINT = 'data_bounds_hint'
+ """The expected bounds of the data in data coordinates.
+
+ A 2-tuple of 2-tuple: ((ymin, ymax), (xmin, xmax)).
+ This provides a hint for the data ranges in both dimensions.
+ It is eventually enlarged with actually data ranges.
+
+ WARNING: dimension 0 i.e., Y first.
+ """
+
_SUPPORTED_VISUALIZATION_PARAMETER_VALUES = {
VisualizationParameter.GRID_MAJOR_ORDER: ('row', 'column'),
VisualizationParameter.BINNED_STATISTIC_FUNCTION: ('mean', 'count', 'sum'),
@@ -1191,7 +1346,7 @@ class ScatterVisualizationMixIn(ItemMixInBase):
return self.getVisualizationParameter(parameter)
-class PointsBase(Item, SymbolMixIn, AlphaMixIn):
+class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
"""Base class for :class:`Curve` and :class:`Scatter`"""
# note: _logFilterData must be overloaded if you overload
# getData to change its signature
@@ -1201,7 +1356,7 @@ class PointsBase(Item, SymbolMixIn, AlphaMixIn):
on top of images."""
def __init__(self):
- Item.__init__(self)
+ DataItem.__init__(self)
SymbolMixIn.__init__(self)
AlphaMixIn.__init__(self)
self._x = ()
@@ -1244,18 +1399,18 @@ class PointsBase(Item, SymbolMixIn, AlphaMixIn):
# expand errorbars to 2xN
if error.size == 1: # Scalar
error = numpy.full(
- (2, len(value)), error, dtype=numpy.float)
+ (2, len(value)), error, dtype=numpy.float64)
elif error.ndim == 1: # N array
newError = numpy.empty((2, len(value)),
- dtype=numpy.float)
+ dtype=numpy.float64)
newError[0, :] = error
newError[1, :] = error
error = newError
elif error.size == 2 * len(value): # 2xN array
error = numpy.array(
- error, copy=True, dtype=numpy.float)
+ error, copy=True, dtype=numpy.float64)
else:
_logger.error("Unhandled error array")
@@ -1309,9 +1464,9 @@ class PointsBase(Item, SymbolMixIn, AlphaMixIn):
if numpy.any(clipped):
# copy to keep original array and convert to float
- x = numpy.array(x, copy=True, dtype=numpy.float)
+ x = numpy.array(x, copy=True, dtype=numpy.float64)
x[clipped] = numpy.nan
- y = numpy.array(y, copy=True, dtype=numpy.float)
+ y = numpy.array(y, copy=True, dtype=numpy.float64)
y[clipped] = numpy.nan
if xPositive and xerror is not None:
@@ -1347,15 +1502,11 @@ class PointsBase(Item, SymbolMixIn, AlphaMixIn):
else:
x, y, _xerror, _yerror = data
- with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=RuntimeWarning)
- # Ignore All-NaN slice encountered
- self._boundsCache[(xPositive, yPositive)] = (
- numpy.nanmin(x),
- numpy.nanmax(x),
- numpy.nanmin(y),
- numpy.nanmax(y)
- )
+ xmin, xmax = min_max(x, finite=True)
+ ymin, ymax = min_max(y, finite=True)
+ self._boundsCache[(xPositive, yPositive)] = tuple([
+ (bound if bound is not None else numpy.nan)
+ for bound in (xmin, xmax, ymin, ymax)])
return self._boundsCache[(xPositive, yPositive)]
def _getCachedData(self):
@@ -1477,11 +1628,7 @@ class PointsBase(Item, SymbolMixIn, AlphaMixIn):
self._filteredCache = {} # Reset cached filtered data
self._clippedCache = {} # Reset cached clipped bool array
- # TODO hackish data range implementation
- if self.isVisible():
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
+ self._boundsChanged()
self._updated(ItemChangedType.DATA)
diff --git a/silx/gui/plot/items/curve.py b/silx/gui/plot/items/curve.py
index 7922fa1..75e7f01 100644
--- a/silx/gui/plot/items/curve.py
+++ b/silx/gui/plot/items/curve.py
@@ -185,15 +185,6 @@ class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn,
self._setBaseline(Curve._DEFAULT_BASELINE)
- self.sigItemChanged.connect(self.__itemChanged)
-
- def __itemChanged(self, event):
- if event == ItemChangedType.YAXIS:
- # TODO hackish data range implementation
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
-
def _addBackendRenderer(self, backend):
"""Update backend renderer"""
# Filter-out values <= 0
@@ -251,20 +242,6 @@ class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn,
else:
raise IndexError("Index out of range: %s", str(item))
- def setVisible(self, visible):
- """Set visibility of item.
-
- :param bool visible: True to display it, False otherwise
- """
- visible = bool(visible)
- # TODO hackish data range implementation
- if self.isVisible() != visible:
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
-
- super(Curve, self).setVisible(visible)
-
@deprecated(replacement='Curve.getHighlightedStyle().getColor()',
since_version='0.9.0')
def getHighlightedColor(self):
diff --git a/silx/gui/plot/items/histogram.py b/silx/gui/plot/items/histogram.py
index 935f8d5..5941cc6 100644
--- a/silx/gui/plot/items/histogram.py
+++ b/silx/gui/plot/items/histogram.py
@@ -38,7 +38,7 @@ try:
except ImportError: # Python2 support
import collections as abc
-from .core import (Item, AlphaMixIn, BaselineMixIn, ColorMixIn, FillMixIn,
+from .core import (DataItem, AlphaMixIn, BaselineMixIn, ColorMixIn, FillMixIn,
LineMixIn, YAxisMixIn, ItemChangedType)
_logger = logging.getLogger(__name__)
@@ -100,7 +100,7 @@ def _getHistogramCurve(histogram, edges):
# TODO: Yerror, test log scale
-class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn,
+class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
LineMixIn, YAxisMixIn, BaselineMixIn):
"""Description of an histogram"""
@@ -119,7 +119,7 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn,
_DEFAULT_BASELINE = None
def __init__(self):
- Item.__init__(self)
+ DataItem.__init__(self)
AlphaMixIn.__init__(self)
BaselineMixIn.__init__(self)
ColorMixIn.__init__(self)
@@ -157,8 +157,8 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn,
(x <= 0) if xPositive else False,
(y <= 0) if yPositive else False)
# Make a copy and replace negative points by NaN
- x = numpy.array(x, dtype=numpy.float)
- y = numpy.array(y, dtype=numpy.float)
+ x = numpy.array(x, dtype=numpy.float64)
+ y = numpy.array(y, dtype=numpy.float64)
x[clipped] = numpy.nan
y[clipped] = numpy.nan
@@ -187,17 +187,17 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn,
yPositive = False
if xPositive or yPositive:
- values = numpy.array(values, copy=True, dtype=numpy.float)
+ values = numpy.array(values, copy=True, dtype=numpy.float64)
if xPositive:
# Replace edges <= 0 by NaN and corresponding values by NaN
clipped_edges = (edges <= 0)
- edges = numpy.array(edges, copy=True, dtype=numpy.float)
+ edges = numpy.array(edges, copy=True, dtype=numpy.float64)
edges[clipped_edges] = numpy.nan
clipped_values = numpy.logical_or(clipped_edges[:-1],
clipped_edges[1:])
else:
- clipped_values = numpy.zeros_like(values, dtype=numpy.bool)
+ clipped_values = numpy.zeros_like(values, dtype=bool)
if yPositive:
# Replace values <= 0 by NaN, do not modify edges
@@ -219,19 +219,6 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn,
min(0, numpy.nanmin(values)),
max(0, numpy.nanmax(values)))
- def setVisible(self, visible):
- """Set visibility of item.
-
- :param bool visible: True to display it, False otherwise
- """
- visible = bool(visible)
- # TODO hackish data range implementation
- if self.isVisible() != visible:
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
- super(Histogram, self).setVisible(visible)
-
def getValueData(self, copy=True):
"""The values of the histogram
@@ -314,11 +301,7 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn,
self._alignement = align
self._setBaseline(baseline)
- if self.isVisible():
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
-
+ self._boundsChanged()
self._updated(ItemChangedType.DATA)
def getAlignment(self):
diff --git a/silx/gui/plot/items/image.py b/silx/gui/plot/items/image.py
index 91c051d..fda4245 100644
--- a/silx/gui/plot/items/image.py
+++ b/silx/gui/plot/items/image.py
@@ -40,7 +40,7 @@ import logging
import numpy
from ....utils.proxy import docstring
-from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn,
+from .core import (DataItem, LabelsMixIn, DraggableMixIn, ColormapMixIn,
AlphaMixIn, ItemChangedType)
@@ -87,15 +87,20 @@ def _convertImageToRgba32(image, copy=True):
return numpy.array(image, copy=copy)
-class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn):
- """Description of an image"""
+class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
+ """Description of an image
- def __init__(self):
- Item.__init__(self)
+ :param numpy.ndarray data: Initial image data
+ """
+
+ def __init__(self, data=None):
+ DataItem.__init__(self)
LabelsMixIn.__init__(self)
DraggableMixIn.__init__(self)
AlphaMixIn.__init__(self)
- self._data = numpy.zeros((0, 0, 4), dtype=numpy.uint8)
+ if data is None:
+ data = numpy.zeros((0, 0, 4), dtype=numpy.uint8)
+ self._data = data
self._origin = (0., 0.)
self._scale = (1., 1.)
@@ -129,19 +134,6 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn):
else:
raise IndexError("Index out of range: %s" % str(item))
- def setVisible(self, visible):
- """Set visibility of item.
-
- :param bool visible: True to display it, False otherwise
- """
- visible = bool(visible)
- # TODO hackish data range implementation
- if self.isVisible() != visible:
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
- super(ImageBase, self).setVisible(visible)
-
def _isPlotLinear(self, plot):
"""Return True if plot only uses linear scale for both of x and y
axes."""
@@ -189,6 +181,15 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn):
"""
return numpy.array(self._data, copy=copy)
+ def setData(self, data):
+ """Set the image data
+
+ :param numpy.ndarray data:
+ """
+ self._data = data
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
def getRgbaImageData(self, copy=True):
"""Get the displayed RGB(A) image
@@ -215,13 +216,7 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn):
origin = float(origin), float(origin)
if origin != self._origin:
self._origin = origin
-
- # TODO hackish data range implementation
- if self.isVisible():
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
-
+ self._boundsChanged()
self._updated(ItemChangedType.POSITION)
def getScale(self):
@@ -244,13 +239,7 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn):
if scale != self._scale:
self._scale = scale
-
- # TODO hackish data range implementation
- if self.isVisible():
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
-
+ self._boundsChanged()
self._updated(ItemChangedType.SCALE)
@@ -258,9 +247,8 @@ class ImageData(ImageBase, ColormapMixIn):
"""Description of a data image with a colormap"""
def __init__(self):
- ImageBase.__init__(self)
+ ImageBase.__init__(self, numpy.zeros((0, 0), dtype=numpy.float32))
ColormapMixIn.__init__(self)
- self._data = numpy.zeros((0, 0), dtype=numpy.float32)
self._alternativeImage = None
self.__alpha = None
@@ -370,7 +358,6 @@ class ImageData(ImageBase, ColormapMixIn):
_logger.warning(
'Converting complex image to absolute value to plot it.')
data = numpy.absolute(data)
- self._data = data
self._setColormappedData(data, copy=False)
if alternative is not None:
@@ -389,20 +376,14 @@ class ImageData(ImageBase, ColormapMixIn):
alpha = numpy.clip(alpha, 0., 1.)
self.__alpha = alpha
- # TODO hackish data range implementation
- if self.isVisible():
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
-
- self._updated(ItemChangedType.DATA)
+ super().setData(data)
class ImageRgba(ImageBase):
"""Description of an RGB(A) image"""
def __init__(self):
- ImageBase.__init__(self)
+ ImageBase.__init__(self, numpy.zeros((0, 0, 4), dtype=numpy.uint8))
def _addBackendRenderer(self, backend):
"""Update backend renderer"""
@@ -440,15 +421,7 @@ class ImageRgba(ImageBase):
data = numpy.array(data, copy=copy)
assert data.ndim == 3
assert data.shape[-1] in (3, 4)
- self._data = data
-
- # TODO hackish data range implementation
- if self.isVisible():
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
-
- self._updated(ItemChangedType.DATA)
+ super().setData(data)
class MaskImageData(ImageData):
diff --git a/silx/gui/plot/items/roi.py b/silx/gui/plot/items/roi.py
index ff73fe6..38a1424 100644
--- a/silx/gui/plot/items/roi.py
+++ b/silx/gui/plot/items/roi.py
@@ -36,729 +36,25 @@ __date__ = "28/06/2018"
import logging
import numpy
-import weakref
-from silx.image.shapes import Polygon
-from ....utils.weakref import WeakList
-from ... import qt
from ... import utils
from .. import items
-from ..items import core
from ...colors import rgba
-import silx.utils.deprecation
+from silx.image.shapes import Polygon
from silx.image._boundingbox import _BoundingBox
from ....utils.proxy import docstring
from ..utils.intersections import segments_intersection
+from ._roi_base import _RegionOfInterestBase
+# He following imports have to be exposed by this module
+from ._roi_base import RegionOfInterest
+from ._roi_base import HandleBasedROI
+from ._arc_roi import ArcROI # noqa
+from ._roi_base import InteractionModeMixIn # noqa
+from ._roi_base import RoiInteractionMode # noqa
-logger = logging.getLogger(__name__)
-
-
-class _RegionOfInterestBase(qt.QObject):
- """Base class of 1D and 2D region of interest
-
- :param QObject parent: See QObject
- :param str name: The name of the ROI
- """
-
- sigAboutToBeRemoved = qt.Signal()
- """Signal emitted just before this ROI is removed from its manager."""
-
- sigItemChanged = qt.Signal(object)
- """Signal emitted when item has changed.
-
- It provides a flag describing which property of the item has changed.
- See :class:`ItemChangedType` for flags description.
- """
-
- def __init__(self, parent=None):
- qt.QObject.__init__(self, parent=parent)
- self.__name = ''
-
- def getName(self):
- """Returns the name of the ROI
-
- :return: name of the region of interest
- :rtype: str
- """
- return self.__name
-
- def setName(self, name):
- """Set the name of the ROI
-
- :param str name: name of the region of interest
- """
- name = str(name)
- if self.__name != name:
- self.__name = name
- self._updated(items.ItemChangedType.NAME)
-
- def _updated(self, event=None, checkVisibility=True):
- """Implement Item mix-in update method by updating the plot items
-
- See :class:`~silx.gui.plot.items.Item._updated`
- """
- self.sigItemChanged.emit(event)
-
- def contains(self, position):
- """Returns True if the `position` is in this ROI.
-
- :param tuple[float,float] position: position to check
- :return: True if the value / point is consider to be in the region of
- interest.
- :rtype: bool
- """
- raise NotImplementedError("Base class")
-
-
-class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
- """Object describing a region of interest in a plot.
-
- :param QObject parent:
- The RegionOfInterestManager that created this object
- """
-
- _DEFAULT_LINEWIDTH = 1.
- """Default line width of the curve"""
-
- _DEFAULT_LINESTYLE = '-'
- """Default line style of the curve"""
-
- _DEFAULT_HIGHLIGHT_STYLE = items.CurveStyle(linewidth=2)
- """Default highlight style of the item"""
-
- ICON, NAME, SHORT_NAME = None, None, None
- """Metadata to describe the ROI in labels, tooltips and widgets
-
- Should be set by inherited classes to custom the ROI manager widget.
- """
-
- sigRegionChanged = qt.Signal()
- """Signal emitted everytime the shape or position of the ROI changes"""
-
- sigEditingStarted = qt.Signal()
- """Signal emitted when the user start editing the roi"""
-
- sigEditingFinished = qt.Signal()
- """Signal emitted when the region edition is finished. During edition
- sigEditionChanged will be emitted several times and
- sigRegionEditionFinished only at end"""
-
- def __init__(self, parent=None):
- # Avoid circular dependency
- from ..tools import roi as roi_tools
- assert parent is None or isinstance(parent, roi_tools.RegionOfInterestManager)
- _RegionOfInterestBase.__init__(self, parent)
- core.HighlightedMixIn.__init__(self)
- self._color = rgba('red')
- self._editable = False
- self._selectable = False
- self._focusProxy = None
- self._visible = True
- self._child = WeakList()
-
- def _connectToPlot(self, plot):
- """Called after connection to a plot"""
- for item in self.getItems():
- # This hack is needed to avoid reentrant call from _disconnectFromPlot
- # to the ROI manager. It also speed up the item tests in _itemRemoved
- item._roiGroup = True
- plot.addItem(item)
-
- def _disconnectFromPlot(self, plot):
- """Called before disconnection from a plot"""
- for item in self.getItems():
- # The item could be already be removed by the plot
- if item.getPlot() is not None:
- del item._roiGroup
- plot.removeItem(item)
-
- def _setItemName(self, item):
- """Helper to generate a unique id to a plot item"""
- legend = "__ROI-%d__%d" % (id(self), id(item))
- item.setName(legend)
-
- def setParent(self, parent):
- """Set the parent of the RegionOfInterest
-
- :param Union[None,RegionOfInterestManager] parent: The new parent
- """
- # Avoid circular dependency
- from ..tools import roi as roi_tools
- if (parent is not None and not isinstance(parent, roi_tools.RegionOfInterestManager)):
- raise ValueError('Unsupported parent')
-
- previousParent = self.parent()
- if previousParent is not None:
- previousPlot = previousParent.parent()
- if previousPlot is not None:
- self._disconnectFromPlot(previousPlot)
- super(RegionOfInterest, self).setParent(parent)
- if parent is not None:
- plot = parent.parent()
- if plot is not None:
- self._connectToPlot(plot)
-
- def addItem(self, item):
- """Add an item to the set of this ROI children.
-
- This item will be added and removed to the plot used by the ROI.
-
- If the ROI is already part of a plot, the item will also be added to
- the plot.
-
- It the item do not have a name already, a unique one is generated to
- avoid item collision in the plot.
-
- :param silx.gui.plot.items.Item item: A plot item
- """
- assert item is not None
- self._child.append(item)
- if item.getName() == '':
- self._setItemName(item)
- manager = self.parent()
- if manager is not None:
- plot = manager.parent()
- if plot is not None:
- item._roiGroup = True
- plot.addItem(item)
-
- def removeItem(self, item):
- """Remove an item from this ROI children.
-
- If the item is part of a plot it will be removed too.
-
- :param silx.gui.plot.items.Item item: A plot item
- """
- assert item is not None
- self._child.remove(item)
- plot = item.getPlot()
- if plot is not None:
- del item._roiGroup
- plot.removeItem(item)
-
- def getItems(self):
- """Returns the list of PlotWidget items of this RegionOfInterest.
-
- :rtype: List[~silx.gui.plot.items.Item]
- """
- return tuple(self._child)
-
- @classmethod
- def _getShortName(cls):
- """Return an human readable kind of ROI
-
- :rtype: str
- """
- if hasattr(cls, "SHORT_NAME"):
- name = cls.SHORT_NAME
- if name is None:
- name = cls.__name__
- return name
-
- def getColor(self):
- """Returns the color of this ROI
-
- :rtype: QColor
- """
- return qt.QColor.fromRgbF(*self._color)
-
- def setColor(self, color):
- """Set the color used for this ROI.
-
- :param color: The color to use for ROI shape as
- either a color name, a QColor, a list of uint8 or float in [0, 1].
- """
- color = rgba(color)
- if color != self._color:
- self._color = color
- self._updated(items.ItemChangedType.COLOR)
-
- @silx.utils.deprecation.deprecated(reason='API modification',
- replacement='getName()',
- since_version=0.12)
- def getLabel(self):
- """Returns the label displayed for this ROI.
-
- :rtype: str
- """
- return self.getName()
-
- @silx.utils.deprecation.deprecated(reason='API modification',
- replacement='setName(name)',
- since_version=0.12)
- def setLabel(self, label):
- """Set the label displayed with this ROI.
-
- :param str label: The text label to display
- """
- self.setName(name=label)
-
- def isEditable(self):
- """Returns whether the ROI is editable by the user or not.
-
- :rtype: bool
- """
- return self._editable
-
- def setEditable(self, editable):
- """Set whether the ROI can be changed interactively.
-
- :param bool editable: True to allow edition by the user,
- False to disable.
- """
- editable = bool(editable)
- if self._editable != editable:
- self._editable = editable
- self._updated(items.ItemChangedType.EDITABLE)
-
- def isSelectable(self):
- """Returns whether the ROI is selectable by the user or not.
-
- :rtype: bool
- """
- return self._selectable
-
- def setSelectable(self, selectable):
- """Set whether the ROI can be selected interactively.
-
- :param bool selectable: True to allow selection by the user,
- False to disable.
- """
- selectable = bool(selectable)
- if self._selectable != selectable:
- self._selectable = selectable
- self._updated(items.ItemChangedType.SELECTABLE)
-
- def getFocusProxy(self):
- """Returns the ROI which have to be selected when this ROI is selected,
- else None if no proxy specified.
-
- :rtype: RegionOfInterest
- """
- proxy = self._focusProxy
- if proxy is None:
- return None
- proxy = proxy()
- if proxy is None:
- self._focusProxy = None
- return proxy
-
- def setFocusProxy(self, roi):
- """Set the real ROI which will be selected when this ROI is selected,
- else None to remove the proxy already specified.
-
- :param RegionOfInterest roi: A ROI
- """
- if roi is not None:
- self._focusProxy = weakref.ref(roi)
- else:
- self._focusProxy = None
-
- def isVisible(self):
- """Returns whether the ROI is visible in the plot.
-
- .. note::
- This does not take into account whether or not the plot
- widget itself is visible (unlike :meth:`QWidget.isVisible` which
- checks the visibility of all its parent widgets up to the window)
-
- :rtype: bool
- """
- return self._visible
-
- def setVisible(self, visible):
- """Set whether the plot items associated with this ROI are
- visible in the plot.
-
- :param bool visible: True to show the ROI in the plot, False to
- hide it.
- """
- visible = bool(visible)
- if self._visible != visible:
- self._visible = visible
- self._updated(items.ItemChangedType.VISIBLE)
-
- @classmethod
- def showFirstInteractionShape(cls):
- """Returns True if the shape created by the first interaction and
- managed by the plot have to be visible.
-
- :rtype: bool
- """
- return False
-
- @classmethod
- def getFirstInteractionShape(cls):
- """Returns the shape kind which will be used by the very first
- interaction with the plot.
-
- This interactions are hardcoded inside the plot
-
- :rtype: str
- """
- return cls._plotShape
-
- def setFirstShapePoints(self, points):
- """"Initialize the ROI using the points from the first interaction.
-
- This interaction is constrained by the plot API and only supports few
- shapes.
- """
- raise NotImplementedError()
-
- def creationStarted(self):
- """"Called when the ROI creation interaction was started.
- """
- pass
-
- @docstring(_RegionOfInterestBase)
- def contains(self, position):
- raise NotImplementedError("Base class")
-
- def creationFinalized(self):
- """"Called when the ROI creation interaction was finalized.
- """
- pass
-
- def _updateItemProperty(self, event, source, destination):
- """Update the item property of a destination from an item source.
-
- :param items.ItemChangedType event: Property type to update
- :param silx.gui.plot.items.Item source: The reference for the data
- :param event Union[Item,List[Item]] destination: The item(s) to update
- """
- if not isinstance(destination, (list, tuple)):
- destination = [destination]
- if event == items.ItemChangedType.NAME:
- value = source.getName()
- for d in destination:
- d.setName(value)
- elif event == items.ItemChangedType.EDITABLE:
- value = source.isEditable()
- for d in destination:
- d.setEditable(value)
- elif event == items.ItemChangedType.SELECTABLE:
- value = source.isSelectable()
- for d in destination:
- d._setSelectable(value)
- elif event == items.ItemChangedType.COLOR:
- value = rgba(source.getColor())
- for d in destination:
- d.setColor(value)
- elif event == items.ItemChangedType.LINE_STYLE:
- value = self.getLineStyle()
- for d in destination:
- d.setLineStyle(value)
- elif event == items.ItemChangedType.LINE_WIDTH:
- value = self.getLineWidth()
- for d in destination:
- d.setLineWidth(value)
- elif event == items.ItemChangedType.SYMBOL:
- value = self.getSymbol()
- for d in destination:
- d.setSymbol(value)
- elif event == items.ItemChangedType.SYMBOL_SIZE:
- value = self.getSymbolSize()
- for d in destination:
- d.setSymbolSize(value)
- elif event == items.ItemChangedType.VISIBLE:
- value = self.isVisible()
- for d in destination:
- d.setVisible(value)
- else:
- assert False
-
- def _updated(self, event=None, checkVisibility=True):
- if event == items.ItemChangedType.HIGHLIGHTED:
- style = self.getCurrentStyle()
- self._updatedStyle(event, style)
- else:
- hilighted = self.isHighlighted()
- if hilighted:
- if event == items.ItemChangedType.HIGHLIGHTED_STYLE:
- style = self.getCurrentStyle()
- self._updatedStyle(event, style)
- else:
- if event in [items.ItemChangedType.COLOR,
- items.ItemChangedType.LINE_STYLE,
- items.ItemChangedType.LINE_WIDTH,
- items.ItemChangedType.SYMBOL,
- items.ItemChangedType.SYMBOL_SIZE]:
- style = self.getCurrentStyle()
- self._updatedStyle(event, style)
- super(RegionOfInterest, self)._updated(event, checkVisibility)
-
- def _updatedStyle(self, event, style):
- """Called when the current displayed style of the ROI was changed.
-
- :param event: The event responsible of the change of the style
- :param items.CurveStyle style: The current style
- """
- pass
-
- def getCurrentStyle(self):
- """Returns the current curve style.
-
- Curve style depends on curve highlighting
-
- :rtype: CurveStyle
- """
- baseColor = rgba(self.getColor())
- if isinstance(self, core.LineMixIn):
- baseLinestyle = self.getLineStyle()
- baseLinewidth = self.getLineWidth()
- else:
- baseLinestyle = self._DEFAULT_LINESTYLE
- baseLinewidth = self._DEFAULT_LINEWIDTH
- if isinstance(self, core.SymbolMixIn):
- baseSymbol = self.getSymbol()
- baseSymbolsize = self.getSymbolSize()
- else:
- baseSymbol = 'o'
- baseSymbolsize = 1
-
- if self.isHighlighted():
- style = self.getHighlightedStyle()
- color = style.getColor()
- linestyle = style.getLineStyle()
- linewidth = style.getLineWidth()
- symbol = style.getSymbol()
- symbolsize = style.getSymbolSize()
-
- return items.CurveStyle(
- color=baseColor if color is None else color,
- linestyle=baseLinestyle if linestyle is None else linestyle,
- linewidth=baseLinewidth if linewidth is None else linewidth,
- symbol=baseSymbol if symbol is None else symbol,
- symbolsize=baseSymbolsize if symbolsize is None else symbolsize)
- else:
- return items.CurveStyle(color=baseColor,
- linestyle=baseLinestyle,
- linewidth=baseLinewidth,
- symbol=baseSymbol,
- symbolsize=baseSymbolsize)
-
- def _editingStarted(self):
- assert self._editable is True
- self.sigEditingStarted.emit()
-
- def _editingFinished(self):
- self.sigEditingFinished.emit()
-
-
-class HandleBasedROI(RegionOfInterest):
- """Manage a ROI based on a set of handles"""
-
- def __init__(self, parent=None):
- RegionOfInterest.__init__(self, parent=parent)
- self._handles = []
- self._posOrigin = None
- self._posPrevious = None
-
- def addUserHandle(self, item=None):
- """
- Add a new free handle to the ROI.
-
- This handle do nothing. It have to be managed by the ROI
- implementing this class.
-
- :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
- add, else None to create a default marker.
- :rtype: silx.gui.plot.items.Marker
- """
- return self.addHandle(item, role="user")
-
- def addLabelHandle(self, item=None):
- """
- Add a new label handle to the ROI.
-
- This handle is not draggable nor selectable.
-
- It is displayed without symbol, but it is always visible anyway
- the ROI is editable, in order to display text.
-
- :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
- add, else None to create a default marker.
- :rtype: silx.gui.plot.items.Marker
- """
- return self.addHandle(item, role="label")
-
- def addTranslateHandle(self, item=None):
- """
- Add a new translate handle to the ROI.
-
- Dragging translate handles affect the position position of the ROI
- but not the shape itself.
-
- :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
- add, else None to create a default marker.
- :rtype: silx.gui.plot.items.Marker
- """
- return self.addHandle(item, role="translate")
-
- def addHandle(self, item=None, role="default"):
- """
- Add a new handle to the ROI.
-
- Dragging handles while affect the position or the shape of the
- ROI.
-
- :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
- add, else None to create a default marker.
- :rtype: silx.gui.plot.items.Marker
- """
- if item is None:
- item = items.Marker()
- color = rgba(self.getColor())
- color = self._computeHandleColor(color)
- item.setColor(color)
- if role == "default":
- item.setSymbol("s")
- elif role == "user":
- pass
- elif role == "translate":
- item.setSymbol("+")
- elif role == "label":
- item.setSymbol("")
-
- if role == "user":
- pass
- elif role == "label":
- item._setSelectable(False)
- item._setDraggable(False)
- item.setVisible(True)
- else:
- self.__updateEditable(item, self.isEditable(), remove=False)
- item._setSelectable(False)
-
- self._handles.append((item, role))
- self.addItem(item)
- return item
-
- def removeHandle(self, handle):
- data = [d for d in self._handles if d[0] is handle][0]
- self._handles.remove(data)
- role = data[1]
- if role not in ["user", "label"]:
- if self.isEditable():
- self.__updateEditable(handle, False)
- self.removeItem(handle)
-
- def getHandles(self):
- """Returns the list of handles of this HandleBasedROI.
-
- :rtype: List[~silx.gui.plot.items.Marker]
- """
- return tuple(data[0] for data in self._handles)
-
- def _updated(self, event=None, checkVisibility=True):
- """Implement Item mix-in update method by updating the plot items
-
- See :class:`~silx.gui.plot.items.Item._updated`
- """
- if event == items.ItemChangedType.NAME:
- self._updateText(self.getName())
- elif event == items.ItemChangedType.VISIBLE:
- for item, role in self._handles:
- visible = self.isVisible()
- editionVisible = visible and self.isEditable()
- if role not in ["user", "label"]:
- item.setVisible(editionVisible)
- else:
- item.setVisible(visible)
- elif event == items.ItemChangedType.EDITABLE:
- for item, role in self._handles:
- editable = self.isEditable()
- if role not in ["user", "label"]:
- self.__updateEditable(item, editable)
- super(HandleBasedROI, self)._updated(event, checkVisibility)
-
- def _updatedStyle(self, event, style):
- super(HandleBasedROI, self)._updatedStyle(event, style)
-
- # Update color of shape items in the plot
- color = rgba(self.getColor())
- handleColor = self._computeHandleColor(color)
- for item, role in self._handles:
- if role == 'user':
- pass
- elif role == 'label':
- item.setColor(color)
- else:
- item.setColor(handleColor)
-
- def __updateEditable(self, handle, editable, remove=True):
- # NOTE: visibility change emit a position update event
- handle.setVisible(editable and self.isVisible())
- handle._setDraggable(editable)
- if editable:
- handle.sigDragStarted.connect(self._handleEditingStarted)
- handle.sigItemChanged.connect(self._handleEditingUpdated)
- handle.sigDragFinished.connect(self._handleEditingFinished)
- else:
- if remove:
- handle.sigDragStarted.disconnect(self._handleEditingStarted)
- handle.sigItemChanged.disconnect(self._handleEditingUpdated)
- handle.sigDragFinished.disconnect(self._handleEditingFinished)
-
- def _handleEditingStarted(self):
- super(HandleBasedROI, self)._editingStarted()
- handle = self.sender()
- self._posOrigin = numpy.array(handle.getPosition())
- self._posPrevious = numpy.array(self._posOrigin)
- self.handleDragStarted(handle, self._posOrigin)
-
- def _handleEditingUpdated(self):
- if self._posOrigin is None:
- # Avoid to handle events when visibility change
- return
- handle = self.sender()
- current = numpy.array(handle.getPosition())
- self.handleDragUpdated(handle, self._posOrigin, self._posPrevious, current)
- self._posPrevious = current
-
- def _handleEditingFinished(self):
- handle = self.sender()
- current = numpy.array(handle.getPosition())
- self.handleDragFinished(handle, self._posOrigin, current)
- self._posPrevious = None
- self._posOrigin = None
- super(HandleBasedROI, self)._editingFinished()
-
- def isHandleBeingDragged(self):
- """Returns True if one of the handles is currently being dragged.
-
- :rtype: bool
- """
- return self._posOrigin is not None
-
- def handleDragStarted(self, handle, origin):
- """Called when an handler drag started"""
- pass
-
- def handleDragUpdated(self, handle, origin, previous, current):
- """Called when an handle drag position changed"""
- pass
-
- def handleDragFinished(self, handle, origin, current):
- """Called when an handle drag finished"""
- pass
-
- def _computeHandleColor(self, color):
- """Returns the anchor color from the base ROI color
- :param Union[numpy.array,Tuple,List]: color
- :rtype: Union[numpy.array,Tuple,List]
- """
- return color[:3] + (0.5,)
-
- def _updateText(self, text):
- """Update the text displayed by this ROI
-
- :param str text: A text
- """
- pass
+logger = logging.getLogger(__name__)
class PointROI(RegionOfInterest, items.SymbolMixIn):
@@ -821,7 +117,8 @@ class PointROI(RegionOfInterest, items.SymbolMixIn):
@docstring(_RegionOfInterestBase)
def contains(self, position):
- raise NotImplementedError('Base class')
+ roiPos = self.getPosition()
+ return position[0] == roiPos[0] and position[1] == roiPos[1]
def _pointPositionChanged(self, event):
"""Handle position changed events of the marker"""
@@ -1022,11 +319,12 @@ class LineROI(HandleBasedROI, items.LineMixIn):
top_left = position[0], position[1] + 1
top_right = position[0] + 1, position[1] + 1
- line_pt1 = self._points[0]
- line_pt2 = self._points[1]
+ points = self.__shape.getPoints()
+ line_pt1 = points[0]
+ line_pt2 = points[1]
- bb1 = _BoundingBox.from_points(self._points)
- if bb1.contains(position) is False:
+ bb1 = _BoundingBox.from_points(points)
+ if not bb1.contains(position):
return False
return (
@@ -1038,7 +336,7 @@ class LineROI(HandleBasedROI, items.LineMixIn):
seg2_start_pt=top_right, seg2_end_pt=top_left) or
segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2,
seg2_start_pt=top_left, seg2_end_pt=bottom_left)
- )
+ ) is not None
def __str__(self):
start, end = self.getEndPoints()
@@ -1106,7 +404,7 @@ class HorizontalLineROI(RegionOfInterest, items.LineMixIn):
@docstring(_RegionOfInterestBase)
def contains(self, position):
- return position[1] == self.getPosition()[1]
+ return position[1] == self.getPosition()
def _linePositionChanged(self, event):
"""Handle position changed events of the marker"""
@@ -1175,7 +473,7 @@ class VerticalLineROI(RegionOfInterest, items.LineMixIn):
@docstring(RegionOfInterest)
def contains(self, position):
- return position[0] == self.getPosition()[0]
+ return position[0] == self.getPosition()
def _linePositionChanged(self, event):
"""Handle position changed events of the marker"""
@@ -1515,6 +813,10 @@ class CircleROI(HandleBasedROI, items.LineMixIn):
center = self.getCenter()
self.setRadius(numpy.linalg.norm(center - current))
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ return numpy.linalg.norm(self.getCenter() - position) <= self.getRadius()
+
def __str__(self):
center = self.getCenter()
radius = self.getRadius()
@@ -1726,7 +1028,7 @@ class EllipseROI(HandleBasedROI, items.LineMixIn):
orientation = self.getOrientation()
if self._radius[1] > self._radius[0]:
# _handleAxis1 is the major axis
- orientation -= numpy.pi/2
+ orientation -= numpy.pi / 2
point0 = numpy.array([center[0] + self._radius[0] * numpy.cos(orientation),
center[1] + self._radius[0] * numpy.sin(orientation)])
@@ -1760,13 +1062,13 @@ class EllipseROI(HandleBasedROI, items.LineMixIn):
if handle is self._handleAxis1:
if self._radius[0] > distance:
# _handleAxis1 is not the major axis, rotate -90 degrees
- orientation -= numpy.pi/2
+ orientation -= numpy.pi / 2
radius = self._radius[0], distance
else: # _handleAxis0
if self._radius[1] > distance:
# _handleAxis0 is not the major axis, rotate +90 degrees
- orientation += numpy.pi/2
+ orientation += numpy.pi / 2
radius = distance, self._radius[1]
self.setGeometry(radius=radius, orientation=orientation)
@@ -1776,6 +1078,14 @@ class EllipseROI(HandleBasedROI, items.LineMixIn):
if event is items.ItemChangedType.POSITION:
self._updateGeometry()
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ major, minor = self.getMajorRadius(), self.getMinorRadius()
+ delta = self.getOrientation()
+ x, y = position - self.getCenter()
+ return ((x*numpy.cos(delta) + y*numpy.sin(delta))**2/major**2 +
+ (x*numpy.sin(delta) - y*numpy.cos(delta))**2/minor**2) <= 1
+
def __str__(self):
center = self.getCenter()
major = self.getMajorRadius()
@@ -1987,682 +1297,6 @@ class PolygonROI(HandleBasedROI, items.LineMixIn):
self._polygon_shape = None
-class ArcROI(HandleBasedROI, items.LineMixIn):
- """A ROI identifying an arc of a circle with a width.
-
- This ROI provides
- - 3 handle to control the curvature
- - 1 handle to control the weight
- - 1 anchor to translate the shape.
- """
-
- ICON = 'add-shape-arc'
- NAME = 'arc ROI'
- SHORT_NAME = "arc"
- """Metadata for this kind of ROI"""
-
- _plotShape = "line"
- """Plot shape which is used for the first interaction"""
-
- class _Geometry:
- def __init__(self):
- self.center = None
- self.startPoint = None
- self.endPoint = None
- self.radius = None
- self.weight = None
- self.startAngle = None
- self.endAngle = None
- self._closed = None
-
- @classmethod
- def createEmpty(cls):
- zero = numpy.array([0, 0])
- return cls.create(zero, zero.copy(), zero.copy(), 0, 0, 0, 0)
-
- @classmethod
- def createRect(cls, startPoint, endPoint, weight):
- return cls.create(None, startPoint, endPoint, None, weight, None, None, False)
-
- @classmethod
- def createCircle(cls, center, startPoint, endPoint, radius,
- weight, startAngle, endAngle):
- return cls.create(center, startPoint, endPoint, radius,
- weight, startAngle, endAngle, True)
-
- @classmethod
- def create(cls, center, startPoint, endPoint, radius,
- weight, startAngle, endAngle, closed=False):
- g = cls()
- g.center = center
- g.startPoint = startPoint
- g.endPoint = endPoint
- g.radius = radius
- g.weight = weight
- g.startAngle = startAngle
- g.endAngle = endAngle
- g._closed = closed
- return g
-
- def withWeight(self, weight):
- """Create a new geometry with another weight
- """
- return self.create(self.center, self.startPoint, self.endPoint,
- self.radius, weight,
- self.startAngle, self.endAngle, self._closed)
-
- def withRadius(self, radius):
- """Create a new geometry with another radius.
-
- The weight and the center is conserved.
- """
- startPoint = self.center + (self.startPoint - self.center) / self.radius * radius
- endPoint = self.center + (self.endPoint - self.center) / self.radius * radius
- return self.create(self.center, startPoint, endPoint,
- radius, self.weight,
- self.startAngle, self.endAngle, self._closed)
-
- def translated(self, x, y):
- delta = numpy.array([x, y])
- center = None if self.center is None else self.center + delta
- startPoint = None if self.startPoint is None else self.startPoint + delta
- endPoint = None if self.endPoint is None else self.endPoint + delta
- return self.create(center, startPoint, endPoint,
- self.radius, self.weight,
- self.startAngle, self.endAngle, self._closed)
-
- def getKind(self):
- """Returns the kind of shape defined"""
- if self.center is None:
- return "rect"
- elif numpy.isnan(self.startAngle):
- return "point"
- elif self.isClosed():
- if self.weight <= 0 or self.weight * 0.5 >= self.radius:
- return "circle"
- else:
- return "donut"
- else:
- if self.weight * 0.5 < self.radius:
- return "arc"
- else:
- return "camembert"
-
- def isClosed(self):
- """Returns True if the geometry is a circle like"""
- if self._closed is not None:
- return self._closed
- delta = numpy.abs(self.endAngle - self.startAngle)
- self._closed = numpy.isclose(delta, numpy.pi * 2)
- return self._closed
-
- def __str__(self):
- return str((self.center,
- self.startPoint,
- self.endPoint,
- self.radius,
- self.weight,
- self.startAngle,
- self.endAngle,
- self._closed))
-
- def __init__(self, parent=None):
- HandleBasedROI.__init__(self, parent=parent)
- items.LineMixIn.__init__(self)
- self._geometry = self._Geometry.createEmpty()
- self._handleLabel = self.addLabelHandle()
-
- self._handleStart = self.addHandle()
- self._handleStart.setSymbol("o")
- self._handleMid = self.addHandle()
- self._handleMid.setSymbol("o")
- self._handleEnd = self.addHandle()
- self._handleEnd.setSymbol("o")
- self._handleWeight = self.addHandle()
- self._handleWeight._setConstraint(self._arcCurvatureMarkerConstraint)
- self._handleMove = self.addTranslateHandle()
-
- shape = items.Shape("polygon")
- shape.setPoints([[0, 0], [0, 0]])
- shape.setColor(rgba(self.getColor()))
- shape.setFill(False)
- shape.setOverlay(True)
- shape.setLineStyle(self.getLineStyle())
- shape.setLineWidth(self.getLineWidth())
- self.__shape = shape
- self.addItem(shape)
-
- def _updated(self, event=None, checkVisibility=True):
- if event == items.ItemChangedType.VISIBLE:
- self._updateItemProperty(event, self, self.__shape)
- super(ArcROI, self)._updated(event, checkVisibility)
-
- def _updatedStyle(self, event, style):
- super(ArcROI, self)._updatedStyle(event, style)
- self.__shape.setColor(style.getColor())
- self.__shape.setLineStyle(style.getLineStyle())
- self.__shape.setLineWidth(style.getLineWidth())
-
- def setFirstShapePoints(self, points):
- """"Initialize the ROI using the points from the first interaction.
-
- This interaction is constrained by the plot API and only supports few
- shapes.
- """
- # The first shape is a line
- point0 = points[0]
- point1 = points[1]
-
- # Compute a non collinear point for the curvature
- center = (point1 + point0) * 0.5
- normal = point1 - center
- normal = numpy.array((normal[1], -normal[0]))
- defaultCurvature = numpy.pi / 5.0
- weightCoef = 0.20
- mid = center - normal * defaultCurvature
- distance = numpy.linalg.norm(point0 - point1)
- weight = distance * weightCoef
-
- geometry = self._createGeometryFromControlPoints(point0, mid, point1, weight)
- self._geometry = geometry
- self._updateHandles()
-
- def _updateText(self, text):
- self._handleLabel.setText(text)
-
- def _updateMidHandle(self):
- """Keep the same geometry, but update the location of the control
- points.
-
- So calling this function do not trigger sigRegionChanged.
- """
- geometry = self._geometry
-
- if geometry.isClosed():
- start = numpy.array(self._handleStart.getPosition())
- geometry.endPoint = start
- with utils.blockSignals(self._handleEnd):
- self._handleEnd.setPosition(*start)
- midPos = geometry.center + geometry.center - start
- else:
- if geometry.center is None:
- midPos = geometry.startPoint * 0.66 + geometry.endPoint * 0.34
- else:
- midAngle = geometry.startAngle * 0.66 + geometry.endAngle * 0.34
- vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)])
- midPos = geometry.center + geometry.radius * vector
-
- with utils.blockSignals(self._handleMid):
- self._handleMid.setPosition(*midPos)
-
- def _updateWeightHandle(self):
- geometry = self._geometry
- if geometry.center is None:
- # rectangle
- center = (geometry.startPoint + geometry.endPoint) * 0.5
- normal = geometry.endPoint - geometry.startPoint
- normal = numpy.array((normal[1], -normal[0]))
- distance = numpy.linalg.norm(normal)
- if distance != 0:
- normal = normal / distance
- weightPos = center + normal * geometry.weight * 0.5
- else:
- if geometry.isClosed():
- midAngle = geometry.startAngle + numpy.pi * 0.5
- elif geometry.center is not None:
- midAngle = (geometry.startAngle + geometry.endAngle) * 0.5
- vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)])
- weightPos = geometry.center + (geometry.radius + geometry.weight * 0.5) * vector
-
- with utils.blockSignals(self._handleWeight):
- self._handleWeight.setPosition(*weightPos)
-
- def _getWeightFromHandle(self, weightPos):
- geometry = self._geometry
- if geometry.center is None:
- # rectangle
- center = (geometry.startPoint + geometry.endPoint) * 0.5
- return numpy.linalg.norm(center - weightPos) * 2
- else:
- distance = numpy.linalg.norm(geometry.center - weightPos)
- return abs(distance - geometry.radius) * 2
-
- def _updateHandles(self):
- geometry = self._geometry
- with utils.blockSignals(self._handleStart):
- self._handleStart.setPosition(*geometry.startPoint)
- with utils.blockSignals(self._handleEnd):
- self._handleEnd.setPosition(*geometry.endPoint)
-
- self._updateMidHandle()
- self._updateWeightHandle()
-
- self._updateShape()
-
- def _updateCurvature(self, start, mid, end, updateCurveHandles, checkClosed=False):
- """Update the curvature using 3 control points in the curve
-
- :param bool updateCurveHandles: If False curve handles are already at
- the right location
- """
- if updateCurveHandles:
- with utils.blockSignals(self._handleStart):
- self._handleStart.setPosition(*start)
- with utils.blockSignals(self._handleMid):
- self._handleMid.setPosition(*mid)
- with utils.blockSignals(self._handleEnd):
- self._handleEnd.setPosition(*end)
-
- if checkClosed:
- closed = self._isCloseInPixel(start, end)
- else:
- closed = self._geometry.isClosed()
-
- weight = self._geometry.weight
- geometry = self._createGeometryFromControlPoints(start, mid, end, weight, closed=closed)
- self._geometry = geometry
-
- self._updateWeightHandle()
- self._updateShape()
-
- def handleDragUpdated(self, handle, origin, previous, current):
- if handle is self._handleStart:
- mid = numpy.array(self._handleMid.getPosition())
- end = numpy.array(self._handleEnd.getPosition())
- self._updateCurvature(current, mid, end,
- checkClosed=True, updateCurveHandles=False)
- elif handle is self._handleMid:
- if self._geometry.isClosed():
- radius = numpy.linalg.norm(self._geometry.center - current)
- self._geometry = self._geometry.withRadius(radius)
- self._updateHandles()
- else:
- start = numpy.array(self._handleStart.getPosition())
- end = numpy.array(self._handleEnd.getPosition())
- self._updateCurvature(start, current, end, updateCurveHandles=False)
- elif handle is self._handleEnd:
- start = numpy.array(self._handleStart.getPosition())
- mid = numpy.array(self._handleMid.getPosition())
- self._updateCurvature(start, mid, current,
- checkClosed=True, updateCurveHandles=False)
- elif handle is self._handleWeight:
- weight = self._getWeightFromHandle(current)
- self._geometry = self._geometry.withWeight(weight)
- self._updateShape()
- elif handle is self._handleMove:
- delta = current - previous
- self.translate(*delta)
-
- def _isCloseInPixel(self, point1, point2):
- manager = self.parent()
- if manager is None:
- return False
- plot = manager.parent()
- if plot is None:
- return False
- point1 = plot.dataToPixel(*point1)
- if point1 is None:
- return False
- point2 = plot.dataToPixel(*point2)
- if point2 is None:
- return False
- return abs(point1[0] - point2[0]) + abs(point1[1] - point2[1]) < 15
-
- def _normalizeGeometry(self):
- """Keep the same phisical geometry, but with normalized parameters.
- """
- geometry = self._geometry
- if geometry.weight * 0.5 >= geometry.radius:
- radius = (geometry.weight * 0.5 + geometry.radius) * 0.5
- geometry = geometry.withRadius(radius)
- geometry = geometry.withWeight(radius * 2)
- self._geometry = geometry
- return True
- return False
-
- def handleDragFinished(self, handle, origin, current):
- if handle in [self._handleStart, self._handleMid, self._handleEnd]:
- if self._normalizeGeometry():
- self._updateHandles()
- else:
- self._updateMidHandle()
- if self._geometry.isClosed():
- self._handleStart.setSymbol("x")
- self._handleEnd.setSymbol("x")
- else:
- self._handleStart.setSymbol("o")
- self._handleEnd.setSymbol("o")
-
- def _createGeometryFromControlPoints(self, start, mid, end, weight, closed=None):
- """Returns the geometry of the object"""
- if closed or (closed is None and numpy.allclose(start, end)):
- # Special arc: It's a closed circle
- center = (start + mid) * 0.5
- radius = numpy.linalg.norm(start - center)
- v = start - center
- startAngle = numpy.angle(complex(v[0], v[1]))
- endAngle = startAngle + numpy.pi * 2.0
- return self._Geometry.createCircle(center, start, end, radius,
- weight, startAngle, endAngle)
-
- elif numpy.linalg.norm(numpy.cross(mid - start, end - start)) < 1e-5:
- # Degenerated arc, it's a rectangle
- return self._Geometry.createRect(start, end, weight)
- else:
- center, radius = self._circleEquation(start, mid, end)
- v = start - center
- startAngle = numpy.angle(complex(v[0], v[1]))
- v = mid - center
- midAngle = numpy.angle(complex(v[0], v[1]))
- v = end - center
- endAngle = numpy.angle(complex(v[0], v[1]))
-
- # Is it clockwise or anticlockwise
- relativeMid = (endAngle - midAngle + 2 * numpy.pi) % (2 * numpy.pi)
- relativeEnd = (endAngle - startAngle + 2 * numpy.pi) % (2 * numpy.pi)
- if relativeMid < relativeEnd:
- if endAngle < startAngle:
- endAngle += 2 * numpy.pi
- else:
- if endAngle > startAngle:
- endAngle -= 2 * numpy.pi
-
- return self._Geometry.create(center, start, end,
- radius, weight, startAngle, endAngle)
-
- def _createShapeFromGeometry(self, geometry):
- kind = geometry.getKind()
- if kind == "rect":
- # It is not an arc
- # but we can display it as an intermediate shape
- normal = (geometry.endPoint - geometry.startPoint)
- normal = numpy.array((normal[1], -normal[0]))
- distance = numpy.linalg.norm(normal)
- if distance != 0:
- normal /= distance
- points = numpy.array([
- geometry.startPoint + normal * geometry.weight * 0.5,
- geometry.endPoint + normal * geometry.weight * 0.5,
- geometry.endPoint - normal * geometry.weight * 0.5,
- geometry.startPoint - normal * geometry.weight * 0.5])
- elif kind == "point":
- # It is not an arc
- # but we can display it as an intermediate shape
- # NOTE: At least 2 points are expected
- points = numpy.array([geometry.startPoint, geometry.startPoint])
- elif kind == "circle":
- outerRadius = geometry.radius + geometry.weight * 0.5
- angles = numpy.arange(0, 2 * numpy.pi, 0.1)
- # It's a circle
- points = []
- numpy.append(angles, angles[-1])
- for angle in angles:
- direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
- points.append(geometry.center + direction * outerRadius)
- points = numpy.array(points)
- elif kind == "donut":
- innerRadius = geometry.radius - geometry.weight * 0.5
- outerRadius = geometry.radius + geometry.weight * 0.5
- angles = numpy.arange(0, 2 * numpy.pi, 0.1)
- # It's a donut
- points = []
- # NOTE: NaN value allow to create 2 separated circle shapes
- # using a single plot item. It's a kind of cheat
- points.append(numpy.array([float("nan"), float("nan")]))
- for angle in angles:
- direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
- points.insert(0, geometry.center + direction * innerRadius)
- points.append(geometry.center + direction * outerRadius)
- points.append(numpy.array([float("nan"), float("nan")]))
- points = numpy.array(points)
- else:
- innerRadius = geometry.radius - geometry.weight * 0.5
- outerRadius = geometry.radius + geometry.weight * 0.5
-
- delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1
- if geometry.startAngle == geometry.endAngle:
- # Degenerated, it's a line (single radius)
- angle = geometry.startAngle
- direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
- points = []
- points.append(geometry.center + direction * innerRadius)
- points.append(geometry.center + direction * outerRadius)
- return numpy.array(points)
-
- angles = numpy.arange(geometry.startAngle, geometry.endAngle, delta)
- if angles[-1] != geometry.endAngle:
- angles = numpy.append(angles, geometry.endAngle)
-
- if kind == "camembert":
- # It's a part of camembert
- points = []
- points.append(geometry.center)
- points.append(geometry.startPoint)
- delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1
- for angle in angles:
- direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
- points.append(geometry.center + direction * outerRadius)
- points.append(geometry.endPoint)
- points.append(geometry.center)
- elif kind == "arc":
- # It's a part of donut
- points = []
- points.append(geometry.startPoint)
- for angle in angles:
- direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
- points.insert(0, geometry.center + direction * innerRadius)
- points.append(geometry.center + direction * outerRadius)
- points.insert(0, geometry.endPoint)
- points.append(geometry.endPoint)
- else:
- assert False
-
- points = numpy.array(points)
-
- return points
-
- def _updateShape(self):
- geometry = self._geometry
- points = self._createShapeFromGeometry(geometry)
- self.__shape.setPoints(points)
-
- index = numpy.nanargmin(points[:, 1])
- pos = points[index]
- with utils.blockSignals(self._handleLabel):
- self._handleLabel.setPosition(pos[0], pos[1])
-
- if geometry.center is None:
- movePos = geometry.startPoint * 0.34 + geometry.endPoint * 0.66
- elif (geometry.isClosed()
- or abs(geometry.endAngle - geometry.startAngle) > numpy.pi * 0.7):
- movePos = geometry.center
- else:
- moveAngle = geometry.startAngle * 0.34 + geometry.endAngle * 0.66
- vector = numpy.array([numpy.cos(moveAngle), numpy.sin(moveAngle)])
- movePos = geometry.center + geometry.radius * vector
-
- with utils.blockSignals(self._handleMove):
- self._handleMove.setPosition(*movePos)
-
- self.sigRegionChanged.emit()
-
- def getGeometry(self):
- """Returns a tuple containing the geometry of this ROI
-
- It is a symmetric function of :meth:`setGeometry`.
-
- If `startAngle` is smaller than `endAngle` the rotation is clockwise,
- else the rotation is anticlockwise.
-
- :rtype: Tuple[numpy.ndarray,float,float,float,float]
- :raise ValueError: In case the ROI can't be represented as section of
- a circle
- """
- geometry = self._geometry
- if geometry.center is None:
- raise ValueError("This ROI can't be represented as a section of circle")
- return geometry.center, self.getInnerRadius(), self.getOuterRadius(), geometry.startAngle, geometry.endAngle
-
- def isClosed(self):
- """Returns true if the arc is a closed shape, like a circle or a donut.
-
- :rtype: bool
- """
- return self._geometry.isClosed()
-
- def getCenter(self):
- """Returns the center of the circle used to draw arcs of this ROI.
-
- This center is usually outside the the shape itself.
-
- :rtype: numpy.ndarray
- """
- return self._geometry.center
-
- def getStartAngle(self):
- """Returns the angle of the start of the section of this ROI (in radian).
-
- If `startAngle` is smaller than `endAngle` the rotation is clockwise,
- else the rotation is anticlockwise.
-
- :rtype: float
- """
- return self._geometry.startAngle
-
- def getEndAngle(self):
- """Returns the angle of the end of the section of this ROI (in radian).
-
- If `startAngle` is smaller than `endAngle` the rotation is clockwise,
- else the rotation is anticlockwise.
-
- :rtype: float
- """
- return self._geometry.endAngle
-
- def getInnerRadius(self):
- """Returns the radius of the smaller arc used to draw this ROI.
-
- :rtype: float
- """
- geometry = self._geometry
- radius = geometry.radius - geometry.weight * 0.5
- if radius < 0:
- radius = 0
- return radius
-
- def getOuterRadius(self):
- """Returns the radius of the bigger arc used to draw this ROI.
-
- :rtype: float
- """
- geometry = self._geometry
- radius = geometry.radius + geometry.weight * 0.5
- return radius
-
- def setGeometry(self, center, innerRadius, outerRadius, startAngle, endAngle):
- """
- Set the geometry of this arc.
-
- :param numpy.ndarray center: Center of the circle.
- :param float innerRadius: Radius of the smaller arc of the section.
- :param float outerRadius: Weight of the bigger arc of the section.
- It have to be bigger than `innerRadius`
- :param float startAngle: Location of the start of the section (in radian)
- :param float endAngle: Location of the end of the section (in radian).
- If `startAngle` is smaller than `endAngle` the rotation is clockwise,
- else the rotation is anticlockwise.
- """
- assert(innerRadius <= outerRadius)
- assert(numpy.abs(startAngle - endAngle) <= 2 * numpy.pi)
- center = numpy.array(center)
- radius = (innerRadius + outerRadius) * 0.5
- weight = outerRadius - innerRadius
-
- vector = numpy.array([numpy.cos(startAngle), numpy.sin(startAngle)])
- startPoint = center + vector * radius
- vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)])
- endPoint = center + vector * radius
-
- geometry = self._Geometry.create(center, startPoint, endPoint,
- radius, weight,
- startAngle, endAngle, closed=None)
- self._geometry = geometry
- self._updateHandles()
-
- @docstring(HandleBasedROI)
- def contains(self, position):
- # first check distance, fastest
- center = self.getCenter()
- distance = numpy.sqrt((position[1] - center[1]) ** 2 + ((position[0] - center[0])) ** 2)
- is_in_distance = self.getInnerRadius() <= distance <= self.getOuterRadius()
- if not is_in_distance:
- return False
- rel_pos = position[1] - center[1], position[0] - center[0]
- angle = numpy.arctan2(*rel_pos)
- start_angle = self.getStartAngle()
- end_angle = self.getEndAngle()
-
- if start_angle < end_angle:
- # I never succeed to find a condition where start_angle < end_angle
- # so this is untested
- is_in_angle = start_angle <= angle <= end_angle
- else:
- if end_angle < -numpy.pi and angle > 0:
- angle = angle - (numpy.pi *2.0)
- is_in_angle = end_angle <= angle <= start_angle
- return is_in_angle
-
- def translate(self, x, y):
- self._geometry = self._geometry.translated(x, y)
- self._updateHandles()
-
- def _arcCurvatureMarkerConstraint(self, x, y):
- """Curvature marker remains on perpendicular bisector"""
- geometry = self._geometry
- if geometry.center is None:
- center = (geometry.startPoint + geometry.endPoint) * 0.5
- vector = geometry.startPoint - geometry.endPoint
- vector = numpy.array((vector[1], -vector[0]))
- vdist = numpy.linalg.norm(vector)
- if vdist != 0:
- normal = numpy.array((vector[1], -vector[0])) / vdist
- else:
- normal = numpy.array((0, 0))
- else:
- if geometry.isClosed():
- midAngle = geometry.startAngle + numpy.pi * 0.5
- else:
- midAngle = (geometry.startAngle + geometry.endAngle) * 0.5
- normal = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)])
- center = geometry.center
- dist = numpy.dot(normal, (numpy.array((x, y)) - center))
- dist = numpy.clip(dist, geometry.radius, geometry.radius * 2)
- x, y = center + dist * normal
- return x, y
-
- @staticmethod
- def _circleEquation(pt1, pt2, pt3):
- """Circle equation from 3 (x, y) points
-
- :return: Position of the center of the circle and the radius
- :rtype: Tuple[Tuple[float,float],float]
- """
- x, y, z = complex(*pt1), complex(*pt2), complex(*pt3)
- w = z - x
- w /= y - x
- c = (x - y) * (w - abs(w) ** 2) / 2j / w.imag - x
- return numpy.array((-c.real, -c.imag)), abs(c + x)
-
- def __str__(self):
- try:
- center, innerRadius, outerRadius, startAngle, endAngle = self.getGeometry()
- params = center[0], center[1], innerRadius, outerRadius, startAngle, endAngle
- params = 'center: %f %f; radius: %f %f; angles: %f %f' % params
- except ValueError:
- params = "invalid"
- return "%s(%s)" % (self.__class__.__name__, params)
-
-
class HorizontalRangeROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying an horizontal range in a 1D plot."""
@@ -2875,6 +1509,10 @@ class HorizontalRangeROI(RegionOfInterest, items.LineMixIn):
marker = self.sender()
self.setCenter(marker.getXPosition())
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ return self.getMin() <= position[0] <= self.getMax()
+
def __str__(self):
vrange = self.getRange()
params = 'min: %f; max: %f' % vrange
diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py
index 5e7d65b..fd7cfae 100644
--- a/silx/gui/plot/items/scatter.py
+++ b/silx/gui/plot/items/scatter.py
@@ -332,6 +332,8 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
@docstring(ScatterVisualizationMixIn)
def setVisualizationParameter(self, parameter, value):
+ parameter = self.VisualizationParameter.from_value(parameter)
+
if super(Scatter, self).setVisualizationParameter(parameter, value):
if parameter in (self.VisualizationParameter.GRID_BOUNDS,
self.VisualizationParameter.GRID_MAJOR_ORDER,
@@ -339,8 +341,10 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
self.__cacheRegularGridInfo = None
if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
- self.VisualizationParameter.BINNED_STATISTIC_FUNCTION):
- if parameter == self.VisualizationParameter.BINNED_STATISTIC_SHAPE:
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION,
+ self.VisualizationParameter.DATA_BOUNDS_HINT):
+ if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
+ self.VisualizationParameter.DATA_BOUNDS_HINT):
self.__cacheHistogramInfo = None # Clean-up cache
if self.getVisualization() is self.Visualization.BINNED_STATISTIC:
self._updateColormappedData()
@@ -351,7 +355,8 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
@docstring(ScatterVisualizationMixIn)
def getCurrentVisualizationParameter(self, parameter):
value = self.getVisualizationParameter(parameter)
- if value is not None:
+ if (parameter is self.VisualizationParameter.DATA_BOUNDS_HINT or
+ value is not None):
return value # Value has been set, return it
elif parameter is self.VisualizationParameter.GRID_BOUNDS:
@@ -452,6 +457,12 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
ranges = (tuple(min_max(y, finite=True)),
tuple(min_max(x, finite=True)))
+ rangesHint = self.getVisualizationParameter(
+ self.VisualizationParameter.DATA_BOUNDS_HINT)
+ if rangesHint is not None:
+ ranges = tuple((min(dataMin, hintMin), max(dataMax, hintMax))
+ for (dataMin, dataMax), (hintMin, hintMax) in zip(ranges, rangesHint))
+
points = numpy.transpose(numpy.array((y, x)))
counts, sums, bin_edges = Histogramnd(
points,
@@ -850,7 +861,7 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
if numpy.any(clipped):
# copy to keep original array and convert to float
- value = numpy.array(value, copy=True, dtype=numpy.float)
+ value = numpy.array(value, copy=True, dtype=numpy.float64)
value[clipped] = numpy.nan
x, y, xerror, yerror = PointsBase._logFilterData(self, xPositive, yPositive)
diff --git a/silx/gui/plot/items/shape.py b/silx/gui/plot/items/shape.py
index 26aa03b..955dfe3 100644
--- a/silx/gui/plot/items/shape.py
+++ b/silx/gui/plot/items/shape.py
@@ -36,7 +36,9 @@ import numpy
import six
from ... import colors
-from .core import Item, ColorMixIn, FillMixIn, ItemChangedType, LineMixIn, YAxisMixIn
+from .core import (
+ Item, DataItem,
+ ColorMixIn, FillMixIn, ItemChangedType, LineMixIn, YAxisMixIn)
_logger = logging.getLogger(__name__)
@@ -154,7 +156,7 @@ class Shape(Item, ColorMixIn, FillMixIn, LineMixIn):
self._updated(ItemChangedType.LINE_BG_COLOR)
-class BoundingRect(Item, YAxisMixIn):
+class BoundingRect(DataItem, YAxisMixIn):
"""An invisible shape which enforce the plot view to display the defined
space on autoscale.
@@ -166,21 +168,10 @@ class BoundingRect(Item, YAxisMixIn):
"""
def __init__(self):
- Item.__init__(self)
+ DataItem.__init__(self)
YAxisMixIn.__init__(self)
self.__bounds = None
- def _updated(self, event=None, checkVisibility=True):
- if event in (ItemChangedType.YAXIS,
- ItemChangedType.VISIBLE,
- ItemChangedType.DATA):
- # TODO hackish data range implementation
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
-
- super(BoundingRect, self)._updated(event, checkVisibility)
-
def setBounds(self, rect):
"""Set the bounding box of this item in data coordinates
@@ -193,6 +184,7 @@ class BoundingRect(Item, YAxisMixIn):
if rect != self.__bounds:
self.__bounds = rect
+ self._boundsChanged()
self._updated(ItemChangedType.DATA)
def _getBounds(self):
@@ -217,7 +209,7 @@ class BoundingRect(Item, YAxisMixIn):
return self.__bounds
-class _BaseExtent(Item):
+class _BaseExtent(DataItem):
"""Base class for :class:`XAxisExtent` and :class:`YAxisExtent`.
:param str axis: Either 'x' or 'y'.
@@ -225,20 +217,10 @@ class _BaseExtent(Item):
def __init__(self, axis='x'):
assert axis in ('x', 'y')
- Item.__init__(self)
+ DataItem.__init__(self)
self.__axis = axis
self.__range = 1., 100.
- def _updated(self, event=None, checkVisibility=True):
- if event in (ItemChangedType.VISIBLE,
- ItemChangedType.DATA):
- # TODO hackish data range implementation
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
-
- super(_BaseExtent, self)._updated(event, checkVisibility)
-
def setRange(self, min_, max_):
"""Set the range of the extent of this item in data coordinates.
@@ -254,6 +236,7 @@ class _BaseExtent(Item):
if range_ != self.__range:
self.__range = range_
+ self._boundsChanged()
self._updated(ItemChangedType.DATA)
def getRange(self):
diff --git a/silx/gui/plot/matplotlib/__init__.py b/silx/gui/plot/matplotlib/__init__.py
index f42bf53..e787240 100644
--- a/silx/gui/plot/matplotlib/__init__.py
+++ b/silx/gui/plot/matplotlib/__init__.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,49 +23,15 @@
#
# ###########################################################################*/
-from __future__ import absolute_import
-
-"""This module initializes matplotlib and sets-up the backend to use.
-
-It MUST be imported prior to any other import of matplotlib.
-
-It provides the matplotlib :class:`FigureCanvasQTAgg` class corresponding
-to the used backend.
-"""
-
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "02/05/2018"
-
-
-from pkg_resources import parse_version
-import matplotlib
-
-from ... import qt
-
-
-def _matplotlib_use(backend, force):
- """Wrapper of `matplotlib.use` to set-up backend.
-
- It adds extra initialization for PySide and PySide2 with matplotlib < 2.2.
- """
- # This is kept for compatibility with matplotlib < 2.2
- if parse_version(matplotlib.__version__) < parse_version('2.2'):
- if qt.BINDING == 'PySide':
- matplotlib.rcParams['backend.qt4'] = 'PySide'
- if qt.BINDING == 'PySide2':
- matplotlib.rcParams['backend.qt5'] = 'PySide2'
-
- matplotlib.use(backend, force=force)
-
+__date__ = "15/07/2020"
-if qt.BINDING in ('PyQt4', 'PySide'):
- _matplotlib_use('Qt4Agg', force=False)
- from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg # noqa
+from silx.utils.deprecation import deprecated_warning
-elif qt.BINDING in ('PyQt5', 'PySide2'):
- _matplotlib_use('Qt5Agg', force=False)
- from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg # noqa
+deprecated_warning(type_='module',
+ name=__file__,
+ replacement='silx.gui.utils.matplotlib',
+ since_version='0.14.0')
-else:
- raise ImportError("Unsupported Qt binding: %s" % qt.BINDING)
+from silx.gui.utils.matplotlib import FigureCanvasQTAgg # noqa
diff --git a/silx/gui/plot/stats/stats.py b/silx/gui/plot/stats/stats.py
index ad61536..755b185 100644
--- a/silx/gui/plot/stats/stats.py
+++ b/silx/gui/plot/stats/stats.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -22,7 +22,9 @@
# THE SOFTWARE.
#
# ###########################################################################*/
-"""This module provides the :class:`Scatter` item of the :class:`Plot`.
+"""This module provides mechanism relative to stats calculation within a
+:class:`PlotWidget`.
+It also include the implementation of the statistics themselves.
"""
__authors__ = ["H. Payno"]
@@ -31,13 +33,19 @@ __date__ = "06/06/2018"
from collections import OrderedDict
+from functools import lru_cache
import logging
import numpy
+import numpy.ma
from .. import items
-from ....math.combo import min_max
+from ..CurvesROIWidget import ROI
+from ..items.roi import RegionOfInterest
+from ....math.combo import min_max
+from silx.utils.proxy import docstring
+from ....utils.deprecation import deprecated
logger = logging.getLogger(__name__)
@@ -60,7 +68,8 @@ class Stats(OrderedDict):
for stat in _statslist:
self.add(stat)
- def calculate(self, item, plot, onlimits):
+ def calculate(self, item, plot, onlimits, roi, data_changed=False,
+ roi_changed=False):
"""
Call all :class:`Stat` object registered and return the result of the
computation.
@@ -69,38 +78,31 @@ class Stats(OrderedDict):
:param plot: plot containing the item
:param bool onlimits: True if we want to apply statistic only on
visible data.
+ :param roi: region of interest for statistic calculation. Incompatible
+ with the `onlimits` option.
+ :type roi: Union[None, :class:`~_RegionOfInterestBase`]
+ :param bool data_changed: did the data changed since last calculation.
+ :param bool roi_changed: did the associated roi (if any) has changed
+ since last calculation.
:return dict: dictionary with :class:`Stat` name as ket and result
of the calculation as value
"""
- context = None
- # Check for PlotWidget items
- if isinstance(item, items.Curve):
- context = _CurveContext(item, plot, onlimits)
- elif isinstance(item, items.ImageData):
- context = _ImageContext(item, plot, onlimits)
- elif isinstance(item, items.Scatter):
- context = _ScatterContext(item, plot, onlimits)
- elif isinstance(item, items.Histogram):
- context = _HistogramContext(item, plot, onlimits)
- else:
- # Check for SceneWidget items
- from ...plot3d import items as items3d # Lazy import
-
- if isinstance(item, (items3d.Scatter2D, items3d.Scatter3D)):
- context = _plot3DScatterContext(item, plot, onlimits)
- elif isinstance(item, (items3d.ImageData, items3d.ScalarField3D)):
- context = _plot3DArrayContext(item, plot, onlimits)
-
- if context is None:
- raise ValueError('Item type not managed')
-
res = {}
+ context = self._getContext(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
for statName, stat in list(self.items()):
if context.kind not in stat.compatibleKinds:
logger.debug('kind %s not managed by statistic %s'
% (context.kind, stat.name))
res[statName] = None
else:
+ if roi_changed is True:
+ context.clear_mask()
+ if data_changed is True or roi_changed is True:
+ # if data changed or mask changed
+ context.clipData(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
+ # init roi and data
res[statName] = stat.calculate(context)
return res
@@ -109,8 +111,40 @@ class Stats(OrderedDict):
OrderedDict.__setitem__(self, key, value)
def add(self, stat):
+ """Add a :class:`Stat` to the set
+
+ :param Stat stat: stat to add to the set
+ """
self.__setitem__(key=stat.name, value=stat)
+ @staticmethod
+ @lru_cache(maxsize=50)
+ def _getContext(item, plot, onlimits, roi):
+ context = None
+ # Check for PlotWidget items
+ if isinstance(item, items.Curve):
+ context = _CurveContext(item, plot, onlimits, roi=roi)
+ elif isinstance(item, items.ImageData):
+ context = _ImageContext(item, plot, onlimits, roi=roi)
+ elif isinstance(item, items.Scatter):
+ context = _ScatterContext(item, plot, onlimits, roi=roi)
+ elif isinstance(item, items.Histogram):
+ context = _HistogramContext(item, plot, onlimits, roi=roi)
+ else:
+ # Check for SceneWidget items
+ from ...plot3d import items as items3d # Lazy import
+
+ if isinstance(item, (items3d.Scatter2D, items3d.Scatter3D)):
+ context = _plot3DScatterContext(item, plot, onlimits,
+ roi=roi)
+ elif isinstance(item,
+ (items3d.ImageData, items3d.ScalarField3D)):
+ context = _plot3DArrayContext(item, plot, onlimits,
+ roi=roi)
+ if context is None:
+ raise ValueError('Item type not managed')
+ return context
+
class _StatsContext(object):
"""
@@ -127,8 +161,11 @@ class _StatsContext(object):
:param plot: the plot containing the item
:param bool onlimits: True if we want to apply statistic only on
visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlimits` calculation
+ :type roi: Union[None,:class:`_RegionOfInterestBase`]
"""
- def __init__(self, item, kind, plot, onlimits):
+ def __init__(self, item, kind, plot, onlimits, roi):
assert item
assert plot
assert type(onlimits) is bool
@@ -136,9 +173,12 @@ class _StatsContext(object):
self.min = None
self.max = None
self.data = None
+ self.roi = None
+ self.onlimits = onlimits
self.values = None
- """The array of data"""
+ """The array of data with limit filtering if any. Is a numpy.ma.array,
+ meaning that it embed the mask applied by the roi if any"""
self.axes = None
"""A list of array of position on each axis.
@@ -151,11 +191,69 @@ class _StatsContext(object):
and the order is (x, y, z).
"""
- self.createContext(item, plot, onlimits)
+ self.clipData(item, plot, onlimits, roi=roi)
+
+ def clipData(self, item, plot, onlimits, roi):
+ """
+ Clip the data to the current mask to have accurate statistics
+
+ :param item: item for whiwh we want to clip data
+ :param plot: plot containing the item
+ :param onlimits: do we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ :type roi: Union[None,:class:`_RegionOfInterestBase`]
+ """
+ raise NotImplementedError()
- def createContext(self, item, plot, onlimits):
+ def clear_mask(self):
+ """
+ Remove the mask to force recomputation of it on next iteration
+ :return:
+ """
+ raise NotImplementedError()
+
+ @property
+ def mask(self):
+ if self.values is not None:
+ assert isinstance(self.values, numpy.ma.MaskedArray)
+ return self.values.mask
+ else:
+ return None
+
+ @property
+ def is_mask_valid(self, **kwargs):
+ """Return if the mask is valid for the data or need to be recomputed"""
+ raise NotImplementedError("Base class")
+
+ def _set_mask_validity(self, **kwargs):
+ """User to set some values that allows to define the mask properties
+ and boundaries"""
raise NotImplementedError("Base class")
+ def clipData(self, item, plot, onlimits, roi):
+ """
+ Function called before computing each statistics associated to this
+ context. It will insure the context for the (item, plot, onlimits, roi)
+ is created.
+
+ :param item: item for which we want statistics
+ :param plot: plot containing the statistics
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlimits` calculation
+ :type roi: Union[None,:class:`_RegionOfInterestBase`]
+ """
+ raise NotImplementedError("Base class")
+
+ @deprecated(reason="context are now stored and keep during stats life."
+ "So this function will be called only once",
+ replacement="clipData", since_version="0.13.0")
+ def createContext(self, item, plot, onlimits, roi):
+ return self.clipData(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
+
def isStructuredData(self):
"""Returns True if data as an array-like structure.
@@ -184,8 +282,34 @@ class _StatsContext(object):
else:
return self.values.ndim == 1
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ if roi is not None and onlimits is True:
+ raise ValueError('Stats context is unable to manage both a ROI'
+ 'and the `onlimits` option')
+
+
+class _ScatterCurveHistoMixInContext(_StatsContext):
+ def __init__(self, kind, item, plot, onlimits, roi):
+ self.clear_mask()
+ _StatsContext.__init__(self, item=item, kind=kind,
+ plot=plot, onlimits=onlimits, roi=roi)
-class _CurveContext(_StatsContext):
+ def _set_mask_validity(self, onlimits, from_, to_):
+ self._onlimits = onlimits
+ self._from_ = from_
+ self._to_ = to_
+
+ def clear_mask(self):
+ self._onlimits = None
+ self._from_ = None
+ self._to_ = None
+
+ def is_mask_valid(self, onlimits, from_, to_):
+ return (onlimits == self.onlimits and from_ == self._from_ and
+ to_ == self._to_)
+
+
+class _CurveContext(_ScatterCurveHistoMixInContext):
"""
StatsContext for :class:`Curve`
@@ -193,32 +317,63 @@ class _CurveContext(_StatsContext):
:param plot: the plot containing the item
:param bool onlimits: True if we want to apply statistic only on
visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
"""
- def __init__(self, item, plot, onlimits):
- _StatsContext.__init__(self, kind='curve', item=item,
- plot=plot, onlimits=onlimits)
-
- def createContext(self, item, plot, onlimits):
+ def __init__(self, item, plot, onlimits, roi):
+ _ScatterCurveHistoMixInContext.__init__(self, kind='curve', item=item,
+ plot=plot, onlimits=onlimits,
+ roi=roi)
+
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
+ self.roi = roi
+ self.onlimits = onlimits
xData, yData = item.getData(copy=True)[0:2]
if onlimits:
minX, maxX = plot.getXAxis().getLimits()
- mask = (minX <= xData) & (xData <= maxX)
+ if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX):
+ mask = self.mask
+ else:
+ mask = (minX <= xData) & (xData <= maxX)
yData = yData[mask]
xData = xData[mask]
+ mask = numpy.zeros_like(yData)
+ elif roi:
+ minX, maxX = roi.getFrom(), roi.getTo()
+ if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX):
+ mask = self.mask
+ else:
+ mask = (minX <= xData) & (xData <= maxX)
+ mask = mask == 0
+ mask = mask.astype(numpy.int32)
+ else:
+ mask = numpy.zeros_like(yData)
self.xData = xData
self.yData = yData
- if len(yData) > 0:
- self.min, self.max = min_max(yData)
+ self.values = numpy.ma.array(yData, mask=mask)
+ unmasked_data = self.values.compressed()
+ if len(unmasked_data) > 0:
+ self.min, self.max = min_max(unmasked_data)
else:
self.min, self.max = None, None
self.data = (xData, yData)
- self.values = yData
+
self.axes = (xData,)
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(self, item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+ if roi is not None and not isinstance(roi, ROI):
+ raise TypeError('curve `context` can ony manage 1D roi')
-class _HistogramContext(_StatsContext):
+
+class _HistogramContext(_ScatterCurveHistoMixInContext):
"""
StatsContext for :class:`Histogram`
@@ -226,32 +381,66 @@ class _HistogramContext(_StatsContext):
:param plot: the plot containing the item
:param bool onlimits: True if we want to apply statistic only on
visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
"""
- def __init__(self, item, plot, onlimits):
- _StatsContext.__init__(self, kind='histogram', item=item,
- plot=plot, onlimits=onlimits)
-
- def createContext(self, item, plot, onlimits):
+ def __init__(self, item, plot, onlimits, roi):
+ _ScatterCurveHistoMixInContext.__init__(self, kind='histogram',
+ item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
yData, edges = item.getData(copy=True)[0:2]
xData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment())
+
if onlimits:
minX, maxX = plot.getXAxis().getLimits()
- mask = (minX <= xData) & (xData <= maxX)
+ if self.is_mask_valid(onlimits, from_=minX, to_=maxX):
+ mask = self.mask
+ else:
+ mask = (minX <= xData) & (xData <= maxX)
+ self._set_mask_validity(onlimits=True, from_=minX, to_=maxX)
+ elif roi:
+ if self.is_mask_valid(onlimits, from_=roi._fromdata, to_=roi._todata):
+ mask = self.mask
+ else:
+ mask = (roi._fromdata <= xData) & (xData <= roi._todata)
+ mask = mask == 0
+ self._set_mask_validity(onlimits=True, from_=roi._fromdata,
+ to_=roi._todata)
+ else:
+ mask = numpy.zeros_like(self.data)
+
+ if onlimits:
yData = yData[mask]
xData = xData[mask]
+ self.data = (xData, yData)
+ self.values = numpy.ma.array(yData, mask=mask)
+ self.axes = (xData,)
+
self.xData = xData
self.yData = yData
- if len(yData) > 0:
- self.min, self.max = min_max(yData)
+
+ unmasked_data = self.values.compressed()
+ if len(unmasked_data) > 0:
+ self.min, self.max = min_max(unmasked_data)
else:
self.min, self.max = None, None
- self.data = (xData, yData)
- self.values = yData
- self.axes = (xData,)
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(self, item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ if roi is not None and not isinstance(roi, ROI):
+ raise TypeError('curve `context` can ony manage 1D roi')
-class _ScatterContext(_StatsContext):
+
+class _ScatterContext(_ScatterCurveHistoMixInContext):
"""StatsContext scatter plots.
It supports :class:`~silx.gui.plot.items.Scatter`.
@@ -260,12 +449,19 @@ class _ScatterContext(_StatsContext):
:param plot: the plot containing the item
:param bool onlimits: True if we want to apply statistic only on
visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
"""
- def __init__(self, item, plot, onlimits):
- _StatsContext.__init__(self, kind='scatter', item=item, plot=plot,
- onlimits=onlimits)
-
- def createContext(self, item, plot, onlimits):
+ def __init__(self, item, plot, onlimits, roi):
+ _ScatterCurveHistoMixInContext.__init__(self, kind='scatter',
+ item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ @docstring(_ScatterCurveHistoMixInContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
valueData = item.getValueData(copy=True)
xData = item.getXData(copy=True)
yData = item.getYData(copy=True)
@@ -283,34 +479,89 @@ class _ScatterContext(_StatsContext):
xData = xData[(minY <= yData) & (yData <= maxY)]
yData = yData[(minY <= yData) & (yData <= maxY)]
- if len(valueData) > 0:
- self.min, self.max = min_max(valueData)
+ if roi:
+ if self.is_mask_valid(onlimits=onlimits, from_=roi.getFrom(),
+ to_=roi.getTo()):
+ mask = self.mask
+ else:
+ mask = (xData < roi.getFrom()) | (xData > roi.getTo())
else:
- self.min, self.max = None, None
+ mask = numpy.zeros_like(xData)
+
self.data = (xData, yData, valueData)
- self.values = valueData
+ self.values = numpy.ma.array(valueData, mask=mask)
self.axes = (xData, yData)
+ unmasked_values = self.values.compressed()
+ if len(unmasked_values) > 0:
+ self.min, self.max = min_max(unmasked_values)
+ else:
+ self.min, self.max = None, None
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(self, item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ if roi is not None and not isinstance(roi, ROI):
+ raise TypeError('curve `context` can ony manage 1D roi')
+
class _ImageContext(_StatsContext):
"""StatsContext for images.
It supports :class:`~silx.gui.plot.items.ImageData`.
+ :warning: behaviour of scale images: now the statistics are computed on
+ the entire data array (there is no sampling in the array or
+ interpolation regarding the scale).
+ This also mean that the result can differ from what is displayed.
+ But I guess there is no perfect behaviour.
+
+ :warning: `isIn` functions for image context: for now have basically a
+ binary approach, the pixel is in a roi or not. To have a fully
+ 'correct behaviour' we should add a weight on stats calculation
+ to moderate the pixel value.
+
:param item: the item for which we want to compute the context
:param plot: the plot containing the item
:param bool onlimits: True if we want to apply statistic only on
visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
"""
- def __init__(self, item, plot, onlimits):
+ def __init__(self, item, plot, onlimits, roi):
+ self.clear_mask()
_StatsContext.__init__(self, kind='image', item=item,
- plot=plot, onlimits=onlimits)
-
- def createContext(self, item, plot, onlimits):
+ plot=plot, onlimits=onlimits, roi=roi)
+
+ def _set_mask_validity(self, xmin: float, xmax: float, ymin: float, ymax
+ : float):
+ self._mask_x_min = xmin
+ self._mask_x_max = xmax
+ self._mask_y_min = ymin
+ self._mask_y_max = ymax
+
+ def clear_mask(self):
+ self._mask_x_min = None
+ self._mask_x_max = None
+ self._mask_y_min = None
+ self._mask_y_max = None
+
+ def is_mask_valid(self, xmin, xmax, ymin, ymax):
+ return (xmin == self._mask_x_min and xmax == self._mask_x_max and
+ ymin == self._mask_y_min and ymax == self._mask_y_max)
+
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
self.origin = item.getOrigin()
self.scale = item.getScale()
self.data = item.getData(copy=True)
+ mask = numpy.zeros_like(self.data)
+ """mask use to know of the stat should be count in or not"""
if onlimits:
minX, maxX = plot.getXAxis().getLimits()
@@ -324,21 +575,50 @@ class _ImageContext(_StatsContext):
XMinBound = max(XMinBound, 0)
YMinBound = max(YMinBound, 0)
+ if onlimits:
if XMaxBound <= XMinBound or YMaxBound <= YMinBound:
self.data = None
else:
self.data = self.data[YMinBound:YMaxBound + 1,
XMinBound:XMaxBound + 1]
- if self.data.size > 0:
- self.min, self.max = min_max(self.data)
+ mask = numpy.zeros_like(self.data)
+ elif roi:
+ minX, maxX = 0, self.data.shape[1]
+ minY, maxY = 0, self.data.shape[0]
+
+ XMinBound = max(minX, 0)
+ YMinBound = max(minY, 0)
+ XMaxBound = min(maxX, self.data.shape[1])
+ YMaxBound = min(maxY, self.data.shape[0])
+
+ if self.is_mask_valid(xmin=XMinBound, xmax=XMaxBound,
+ ymin=YMinBound, ymax=YMaxBound):
+ mask = self.mask
+ else:
+ for x in range(XMinBound, XMaxBound):
+ for y in range(YMinBound, YMaxBound):
+ _x = (x * self.scale[0]) + self.origin[0]
+ _y = (y * self.scale[1]) + self.origin[1]
+ mask[y, x] = not roi.contains((_x, _y))
+ self._set_mask_validity(xmin=XMinBound, xmax=XMaxBound,
+ ymin=YMinBound, ymax=YMaxBound)
+ self.values = numpy.ma.array(self.data, mask=mask)
+ if self.values.compressed().size > 0:
+ self.min, self.max = min_max(self.values.compressed())
else:
self.min, self.max = None, None
- self.values = self.data
if self.values is not None:
self.axes = (self.origin[1] + self.scale[1] * numpy.arange(self.data.shape[0]),
self.origin[0] + self.scale[0] * numpy.arange(self.data.shape[1]))
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(self, item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ if roi is not None and not isinstance(roi, RegionOfInterest):
+ raise TypeError('curve `context` can ony manage 2D roi')
+
class _plot3DScatterContext(_StatsContext):
"""StatsContext for 3D scatter plots.
@@ -350,16 +630,26 @@ class _plot3DScatterContext(_StatsContext):
:param plot: the plot containing the item
:param bool onlimits: True if we want to apply statistic only on
visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
"""
- def __init__(self, item, plot, onlimits):
+ def __init__(self, item, plot, onlimits, roi):
_StatsContext.__init__(self, kind='scatter', item=item, plot=plot,
- onlimits=onlimits)
+ onlimits=onlimits, roi=roi)
- def createContext(self, item, plot, onlimits):
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
if onlimits:
raise RuntimeError("Unsupported plot %s" % str(plot))
-
values = item.getValueData(copy=False)
+ if roi:
+ logger.warning("Roi are unsupported on volume for now")
+ mask = numpy.zeros_like(values)
+ else:
+ mask = numpy.zeros_like(values)
if values is not None and len(values) > 0:
self.values = values
@@ -367,13 +657,20 @@ class _plot3DScatterContext(_StatsContext):
if self.values.ndim == 3:
axes.append(item.getZData(copy=False))
self.axes = tuple(axes)
-
self.min, self.max = min_max(self.values)
+ self.values = numpy.ma.array(self.values, mask=mask)
else:
self.values = None
self.axes = None
self.min, self.max = None, None
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(self, item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ if roi is not None and not isinstance(roi, RegionOfInterest):
+ raise TypeError('curve `context` can ony manage 2D roi')
+
class _plot3DArrayContext(_StatsContext):
"""StatsContext for 3D scalar field and data image.
@@ -385,26 +682,45 @@ class _plot3DArrayContext(_StatsContext):
:param plot: the plot containing the item
:param bool onlimits: True if we want to apply statistic only on
visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
"""
- def __init__(self, item, plot, onlimits):
+ def __init__(self, item, plot, onlimits, roi):
_StatsContext.__init__(self, kind='image', item=item, plot=plot,
- onlimits=onlimits)
+ onlimits=onlimits, roi=roi)
- def createContext(self, item, plot, onlimits):
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
if onlimits:
raise RuntimeError("Unsupported plot %s" % str(plot))
values = item.getData(copy=False)
+ if roi:
+ logger.warning("Roi are unsuported on volume for now")
+ mask = numpy.zeros_like(values)
+ else:
+ mask = numpy.zeros_like(values)
if values is not None and len(values) > 0:
self.values = values
self.axes = tuple([numpy.arange(size) for size in self.values.shape])
self.min, self.max = min_max(self.values)
+ self.values = numpy.ma.array(self.values, mask=mask)
else:
self.values = None
self.axes = None
self.min, self.max = None, None
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(self, item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ if roi is not None and not isinstance(roi, RegionOfInterest):
+ raise TypeError('curve `context` can ony manage 2D roi')
+
BASIC_COMPATIBLE_KINDS = 'curve', 'image', 'scatter', 'histogram'
@@ -456,6 +772,7 @@ class Stat(StatBase):
StatBase.__init__(self, name, kinds)
self._fct = fct
+ @docstring(StatBase)
def calculate(self, context):
if context.values is not None:
if context.kind in self.compatibleKinds:
@@ -472,6 +789,7 @@ class StatMin(StatBase):
def __init__(self):
StatBase.__init__(self, name='min')
+ @docstring(StatBase)
def calculate(self, context):
return context.min
@@ -481,6 +799,7 @@ class StatMax(StatBase):
def __init__(self):
StatBase.__init__(self, name='max')
+ @docstring(StatBase)
def calculate(self, context):
return context.max
@@ -490,6 +809,7 @@ class StatDelta(StatBase):
def __init__(self):
StatBase.__init__(self, name='delta')
+ @docstring(StatBase)
def calculate(self, context):
return context.max - context.min
@@ -506,14 +826,17 @@ class _StatCoord(StatBase):
:param int index: Index in the flattened data array
:rtype: List[int]
"""
- if context.isStructuredData():
+
+ axes = context.axes
+
+ if context.isStructuredData() or context.roi:
coordinates = []
- for axis in reversed(context.axes):
+ for axis in reversed(axes):
coordinates.append(axis[index % len(axis)])
index = index // len(axis)
return tuple(coordinates)
else:
- return tuple(axis[index] for axis in context.axes)
+ return tuple(axis[index] for axis in axes)
class StatCoordMin(_StatCoord):
@@ -521,13 +844,15 @@ class StatCoordMin(_StatCoord):
def __init__(self):
_StatCoord.__init__(self, name='coords min')
+ @docstring(StatBase)
def calculate(self, context):
if context.values is None or not context.isScalarData():
return None
- index = numpy.argmin(context.values)
+ index = context.values.argmin()
return self._indexToCoordinates(context, index)
+ @docstring(StatBase)
def getToolTip(self, kind):
return "Coordinates of the first minimum value of the data"
@@ -537,13 +862,17 @@ class StatCoordMax(_StatCoord):
def __init__(self):
_StatCoord.__init__(self, name='coords max')
+ @docstring(StatBase)
def calculate(self, context):
if context.values is None or not context.isScalarData():
return None
- index = numpy.argmax(context.values)
+ # TODO: the values should be a mask array by default, will be simpler
+ # if possible
+ index = context.values.argmax()
return self._indexToCoordinates(context, index)
+ @docstring(StatBase)
def getToolTip(self, kind):
return "Coordinates of the first maximum value of the data"
@@ -553,11 +882,12 @@ class StatCOM(StatBase):
def __init__(self):
StatBase.__init__(self, name='COM', description='Center of mass')
+ @docstring(StatBase)
def calculate(self, context):
if context.values is None or not context.isScalarData():
return None
- values = numpy.array(context.values, dtype=numpy.float64)
+ values = numpy.ma.array(context.values, mask=context.mask, dtype=numpy.float64)
sum_ = numpy.sum(values)
if sum_ == 0.:
return (numpy.nan,) * len(context.axes)
@@ -573,5 +903,6 @@ class StatCOM(StatBase):
return tuple(
numpy.sum(axis * values) / sum_ for axis in context.axes)
+ @docstring(StatBase)
def getToolTip(self, kind):
return "Compute the center of mass of the dataset"
diff --git a/silx/gui/plot/stats/statshandler.py b/silx/gui/plot/stats/statshandler.py
index f69daff..17578d8 100644
--- a/silx/gui/plot/stats/statshandler.py
+++ b/silx/gui/plot/stats/statshandler.py
@@ -22,7 +22,8 @@
# THE SOFTWARE.
#
# ###########################################################################*/
-"""
+"""This module containts the classes relative to the management of statistics
+display.
"""
__authors__ = ["H. Payno"]
@@ -178,7 +179,8 @@ class StatsHandler(object):
else:
return self.formatters[name].format(val)
- def calculate(self, item, plot, onlimits):
+ def calculate(self, item, plot, onlimits, roi=None, data_changed=False,
+ roi_changed=False):
"""
compute all statistic registered and return the list of formatted
statistics result.
@@ -187,10 +189,14 @@ class StatsHandler(object):
:param plot: plot containing the item
:param onlimits: True if we want to compute statistics on visible data
only
+ :type: bool
+ :param roi: region of interest for statistic calculation
+ :type: Union[None,:class:`_RegionOfInterestBase`]
:return: list of formatted statistics (as str)
:rtype: dict
"""
- res = self.stats.calculate(item, plot, onlimits)
+ res = self.stats.calculate(item, plot, onlimits, roi,
+ data_changed=data_changed, roi_changed=roi_changed)
for resName, resValue in list(res.items()):
res[resName] = self.format(resName, res[resName])
return res
diff --git a/silx/gui/plot/test/__init__.py b/silx/gui/plot/test/__init__.py
index 0477e2a..dfb7c2e 100644
--- a/silx/gui/plot/test/__init__.py
+++ b/silx/gui/plot/test/__init__.py
@@ -53,6 +53,7 @@ from . import testSaveAction
from . import testScatterView
from . import testPixelIntensityHistoAction
from . import testCompareImages
+from . import testRoiStatsWidget
def suite():
@@ -86,5 +87,6 @@ def suite():
testScatterView.suite(),
testPixelIntensityHistoAction.suite(),
testCompareImages.suite(),
+ testRoiStatsWidget.suite(),
])
return test_suite
diff --git a/silx/gui/plot/test/testComplexImageView.py b/silx/gui/plot/test/testComplexImageView.py
index 051ec4d..4ac3488 100644
--- a/silx/gui/plot/test/testComplexImageView.py
+++ b/silx/gui/plot/test/testComplexImageView.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -50,7 +50,7 @@ class TestComplexImageView(PlotWidgetTestCase, ParametricTestCase):
def testPlot2DComplex(self):
"""Test API of ComplexImageView widget"""
- data = numpy.array(((0, 1j), (1, 1 + 1j)), dtype=numpy.complex)
+ data = numpy.array(((0, 1j), (1, 1 + 1j)), dtype=numpy.complex64)
self.plot.setData(data)
self.plot.setKeepDataAspectRatio(True)
self.plot.getPlot().resetZoom()
@@ -76,11 +76,11 @@ class TestComplexImageView(PlotWidgetTestCase, ParametricTestCase):
self.qWait(100)
# Test no data
- self.plot.setData(numpy.zeros((0, 0), dtype=numpy.complex))
+ self.plot.setData(numpy.zeros((0, 0), dtype=numpy.complex64))
self.qWait(100)
# Test float data
- self.plot.setData(numpy.arange(100, dtype=numpy.float).reshape(10, 10))
+ self.plot.setData(numpy.arange(100, dtype=numpy.float64).reshape(10, 10))
self.qWait(100)
diff --git a/silx/gui/plot/test/testCurvesROIWidget.py b/silx/gui/plot/test/testCurvesROIWidget.py
index 77c53a8..6a0ab8c 100644
--- a/silx/gui/plot/test/testCurvesROIWidget.py
+++ b/silx/gui/plot/test/testCurvesROIWidget.py
@@ -375,13 +375,13 @@ class TestRoiWidgetSignals(TestCaseQt):
self.listener.clear()
roi1 = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
- self.curves_roi_widget.roiTable.addRoi(roi1)
+ self.curves_roi_widget.roiTable.registerROI(roi1)
self.assertEqual(self.listener.callCount(), 1)
self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
self.listener.clear()
roi2 = CurvesROIWidget.ROI(name='linear2', fromdata=0, todata=5)
- self.curves_roi_widget.roiTable.addRoi(roi2)
+ self.curves_roi_widget.roiTable.registerROI(roi2)
self.assertEqual(self.listener.callCount(), 1)
self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear2')
self.listener.clear()
@@ -398,7 +398,7 @@ class TestRoiWidgetSignals(TestCaseQt):
self.assertTrue(self.listener.arguments()[0][0]['current'] is None)
self.listener.clear()
- self.curves_roi_widget.roiTable.addRoi(roi1)
+ self.curves_roi_widget.roiTable.registerROI(roi1)
self.assertEqual(self.listener.callCount(), 1)
self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1)
@@ -415,7 +415,7 @@ class TestRoiWidgetSignals(TestCaseQt):
"""Test SigROISignal when modifying it"""
self.curves_roi_widget.roiTable.setMiddleROIMarkerFlag(True)
roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5)
- self.curves_roi_widget.roiTable.addRoi(roi1)
+ self.curves_roi_widget.roiTable.registerROI(roi1)
self.curves_roi_widget.roiTable.setActiveRoi(roi1)
# test modify the roi2 object
@@ -450,7 +450,7 @@ class TestRoiWidgetSignals(TestCaseQt):
def testSetActiveCurve(self):
"""Test sigRoiSignal when set an active curve"""
roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5)
- self.curves_roi_widget.roiTable.addRoi(roi1)
+ self.curves_roi_widget.roiTable.registerROI(roi1)
self.curves_roi_widget.roiTable.setActiveRoi(roi1)
self.listener.clear()
self.plot.setActiveCurve('curve0')
diff --git a/silx/gui/plot/test/testItem.py b/silx/gui/plot/test/testItem.py
index ad739a2..8dacdea 100644
--- a/silx/gui/plot/test/testItem.py
+++ b/silx/gui/plot/test/testItem.py
@@ -35,6 +35,7 @@ import numpy
from silx.gui.utils.testutils import SignalListener
from silx.gui.plot.items import ItemChangedType
+from silx.gui.plot import items
from .utils import PlotWidgetTestCase
@@ -242,11 +243,96 @@ class TestSymbol(PlotWidgetTestCase):
self.assertEqual('Diamond', name)
+class TestVisibleExtent(PlotWidgetTestCase):
+ """Test item's visible extent feature"""
+
+ def testGetVisibleBounds(self):
+ """Test Item.getVisibleBounds"""
+
+ # Create test items (with a bounding box of x: [1,3], y: [0,2])
+ curve = items.Curve()
+ curve.setData((1, 2, 3), (0, 1, 2))
+
+ histogram = items.Histogram()
+ histogram.setData((0, 1, 2), (1, 5/3, 7/3, 3))
+
+ image = items.ImageData()
+ image.setOrigin((1, 0))
+ image.setData(numpy.arange(4).reshape(2, 2))
+
+ scatter = items.Scatter()
+ scatter.setData((1, 2, 3), (0, 1, 2), (1, 2, 3))
+
+ bbox = items.BoundingRect()
+ bbox.setBounds((1, 3, 0, 2))
+
+ xaxis, yaxis = self.plot.getXAxis(), self.plot.getYAxis()
+ for item in (curve, histogram, image, scatter, bbox):
+ with self.subTest(item=item):
+ xaxis.setLimits(0, 100)
+ yaxis.setLimits(0, 100)
+ self.plot.addItem(item)
+ self.assertEqual(item.getVisibleBounds(), (1., 3., 0., 2.))
+
+ xaxis.setLimits(0.5, 2.5)
+ self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0., 2.))
+
+ yaxis.setLimits(0.5, 1.5)
+ self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0.5, 1.5))
+
+ item.setVisible(False)
+ self.assertIsNone(item.getVisibleBounds())
+
+ self.plot.clear()
+
+ def testVisibleExtentTracking(self):
+ """Test Item's visible extent tracking"""
+ image = items.ImageData()
+ image.setData(numpy.arange(6).reshape(2, 3))
+
+ listener = SignalListener()
+ image._sigVisibleBoundsChanged.connect(listener)
+ image._setVisibleBoundsTracking(True)
+ self.assertTrue(image._isVisibleBoundsTracking())
+
+ self.plot.addItem(image)
+ self.assertEqual(listener.callCount(), 1)
+
+ self.plot.getXAxis().setLimits(0, 1)
+ self.assertEqual(listener.callCount(), 2)
+
+ self.plot.hide()
+ self.qapp.processEvents()
+ # No event here
+ self.assertEqual(listener.callCount(), 2)
+
+ self.plot.getXAxis().setLimits(1, 2)
+ # No event since PlotWidget is hidden, delayed to PlotWidget show
+ self.assertEqual(listener.callCount(), 2)
+
+ self.plot.show()
+ self.qapp.processEvents()
+ # Receives delayed event now
+ self.assertEqual(listener.callCount(), 3)
+
+ image.setOrigin((-1, -1))
+ self.assertEqual(listener.callCount(), 4)
+
+ image.setVisible(False)
+ image.setOrigin((0, 0))
+ # No event since item is not visible
+ self.assertEqual(listener.callCount(), 4)
+
+ image.setVisible(True)
+ # Receives delayed event now
+ self.assertEqual(listener.callCount(), 5)
+
+
def suite():
test_suite = unittest.TestSuite()
loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestSigItemChangedSignal))
- test_suite.addTest(loadTests(TestSymbol))
+ for klass in (TestSigItemChangedSignal, TestSymbol, TestVisibleExtent):
+ test_suite.addTest(loadTests(klass))
return test_suite
diff --git a/silx/gui/plot/test/testMaskToolsWidget.py b/silx/gui/plot/test/testMaskToolsWidget.py
index a05c1be..2e8db55 100644
--- a/silx/gui/plot/test/testMaskToolsWidget.py
+++ b/silx/gui/plot/test/testMaskToolsWidget.py
@@ -84,10 +84,15 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.mouseMove(plot, pos=(0, 0))
self.mouseMove(plot, pos=pos0)
- self.mouseClick(plot, qt.Qt.LeftButton, pos=pos0)
- self.mouseMove(plot, pos=(0, 0))
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos0)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2))
self.mouseMove(plot, pos=pos1)
- self.mouseClick(plot, qt.Qt.LeftButton, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=(0, 0))
def _drawPolygon(self):
"""Draw a star polygon in the plot"""
@@ -106,7 +111,9 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
for pos in star:
self.mouseMove(plot, pos=pos)
self.qapp.processEvents()
- self.mouseClick(plot, qt.Qt.LeftButton, pos=pos)
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
self.qapp.processEvents()
def _drawPencil(self):
diff --git a/silx/gui/plot/test/testPlotInteraction.py b/silx/gui/plot/test/testPlotInteraction.py
index 335b1e4..7a30434 100644
--- a/silx/gui/plot/test/testPlotInteraction.py
+++ b/silx/gui/plot/test/testPlotInteraction.py
@@ -68,7 +68,11 @@ class TestSelectPolygon(PlotWidgetTestCase):
for pos in polygon:
self.mouseMove(plot, pos=pos)
- self.mouseClick(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
self.plot.sigPlotSignal.disconnect(dump)
return [args[0] for args in dump.received]
diff --git a/silx/gui/plot/test/testPlotWidget.py b/silx/gui/plot/test/testPlotWidget.py
index 4ef6a72..f9d2281 100755
--- a/silx/gui/plot/test/testPlotWidget.py
+++ b/silx/gui/plot/test/testPlotWidget.py
@@ -43,7 +43,7 @@ from silx.test.utils import test_options
from silx.gui import qt
from silx.gui.plot import PlotWidget
from silx.gui.plot.items.curve import CurveStyle
-from silx.gui.plot.items import BoundingRect, XAxisExtent, YAxisExtent
+from silx.gui.plot.items import BoundingRect, XAxisExtent, YAxisExtent, Axis
from silx.gui.colors import Colormap
from .utils import PlotWidgetTestCase
@@ -326,6 +326,23 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase):
resetzoom=False)
self.plot.resetZoom()
+ def testPlotColormapNaNColor(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle('Colormap with NaN color')
+
+ colormap = Colormap()
+ colormap.setNaNColor('red')
+ self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0))
+ data = DATA_2D.astype(numpy.float32)
+ data[len(data)//2:] = numpy.nan
+ self.plot.addImage(data, legend="image 1", colormap=colormap,
+ resetzoom=False)
+ self.plot.resetZoom()
+
+ colormap.setNaNColor((0., 1., 0., 1.))
+ self.assertEqual(colormap.getNaNColor(), qt.QColor(0, 255, 0))
+ self.qapp.processEvents()
+
def testImageOriginScale(self):
"""Test of image with different origin and scale"""
self.plot.setGraphTitle('origin and scale')
@@ -401,7 +418,7 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase):
def testPlotBooleanImage(self):
"""Test that a boolean image is displayed and converted to int8."""
- data = numpy.zeros((10, 10), dtype=numpy.bool)
+ data = numpy.zeros((10, 10), dtype=bool)
data[::2, ::2] = True
self.plot.addImage(data, legend='boolean')
@@ -438,6 +455,21 @@ class TestPlotCurve(PlotWidgetTestCase):
self.plot.setActiveCurveHandling(False)
+ def testPlotCurveInfinite(self):
+ """Test plot curves with not finite data"""
+ tests = {
+ 'y all not finite': ([0, 1, 2], [numpy.inf, numpy.nan, -numpy.inf]),
+ 'x all not finite': ([numpy.inf, numpy.nan, -numpy.inf], [0, 1, 2]),
+ 'x some inf': ([0, numpy.inf, 2], [0, 1, 2]),
+ 'y some inf': ([0, 1, 2], [0, numpy.inf, 2])
+ }
+ for name, args in tests.items():
+ with self.subTest(name):
+ self.plot.addCurve(*args)
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+ self.plot.clear()
+
def testPlotCurveColorFloat(self):
color = numpy.array(numpy.random.random(3 * 1000),
dtype=numpy.float32).reshape(1000, 3)
@@ -799,17 +831,25 @@ class TestPlotItem(PlotWidgetTestCase):
"""Basic tests for addItem."""
# Polygon coordinates and color
- polygons = [ # legend, x coords, y coords, color
+ POLYGONS = [ # legend, x coords, y coords, color
('triangle', numpy.array((10, 30, 50)),
numpy.array((55, 70, 55)), 'red'),
('square', numpy.array((10, 10, 50, 50)),
numpy.array((10, 50, 50, 10)), 'green'),
('star', numpy.array((60, 70, 80, 60, 80)),
numpy.array((25, 50, 25, 40, 40)), 'blue'),
+ ('2 triangles-simple',
+ numpy.array((90., 95., 100., numpy.nan, 90., 95., 100.)),
+ numpy.array((25., 5., 25., numpy.nan, 30., 50., 30.)),
+ 'pink'),
+ ('2 triangles-extra NaN',
+ numpy.array((numpy.nan, 90., 95., 100., numpy.nan, 0., 90., 95., 100., numpy.nan)),
+ numpy.array((0., 55., 70., 55., numpy.nan, numpy.nan, 75., 90., 75., numpy.nan)),
+ 'black'),
]
# Rectangle coordinantes and color
- rectangles = [ # legend, x coords, y coords, color
+ RECTANGLES = [ # legend, x coords, y coords, color
('square 1', numpy.array((1., 10.)),
numpy.array((1., 10.)), 'red'),
('square 2', numpy.array((10., 20.)),
@@ -822,6 +862,8 @@ class TestPlotItem(PlotWidgetTestCase):
numpy.array((45., 45.)), 'darkRed'),
]
+ SCALES = Axis.LINEAR, Axis.LOGARITHMIC
+
def setUp(self):
super(TestPlotItem, self).setUp()
@@ -833,40 +875,60 @@ class TestPlotItem(PlotWidgetTestCase):
self.plot.setLimits(0., 100., -100., 100.)
def testPlotItemPolygonFill(self):
- self.plot.setGraphTitle('Item Fill')
-
- for legend, xList, yList, color in self.polygons:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False,
- shape="polygon", fill=True, color=color)
- self.plot.resetZoom()
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle('Item Fill %s' % scale)
+
+ for legend, xList, yList, color in self.POLYGONS:
+ self.plot.addShape(xList, yList, legend=legend,
+ replace=False, linestyle='--',
+ shape="polygon", fill=True, color=color)
+ self.plot.resetZoom()
def testPlotItemPolygonNoFill(self):
- self.plot.setGraphTitle('Item No Fill')
-
- for legend, xList, yList, color in self.polygons:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False,
- shape="polygon", fill=False, color=color)
- self.plot.resetZoom()
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle('Item No Fill %s' % scale)
+
+ for legend, xList, yList, color in self.POLYGONS:
+ self.plot.addShape(xList, yList, legend=legend,
+ replace=False, linestyle='--',
+ shape="polygon", fill=False, color=color)
+ self.plot.resetZoom()
def testPlotItemRectangleFill(self):
- self.plot.setGraphTitle('Rectangle Fill')
-
- for legend, xList, yList, color in self.rectangles:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False,
- shape="rectangle", fill=True, color=color)
- self.plot.resetZoom()
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle('Rectangle Fill %s' % scale)
+
+ for legend, xList, yList, color in self.RECTANGLES:
+ self.plot.addShape(xList, yList, legend=legend,
+ replace=False,
+ shape="rectangle", fill=True, color=color)
+ self.plot.resetZoom()
def testPlotItemRectangleNoFill(self):
- self.plot.setGraphTitle('Rectangle No Fill')
-
- for legend, xList, yList, color in self.rectangles:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False,
- shape="rectangle", fill=False, color=color)
- self.plot.resetZoom()
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle('Rectangle No Fill %s' % scale)
+
+ for legend, xList, yList, color in self.RECTANGLES:
+ self.plot.addShape(xList, yList, legend=legend,
+ replace=False,
+ shape="rectangle", fill=False, color=color)
+ self.plot.resetZoom()
class TestPlotActiveCurveImage(PlotWidgetTestCase):
@@ -1384,6 +1446,20 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
"""Test coverage on setAxesDisplayed(True)"""
self.plot.setAxesDisplayed(True)
+ def testAxesMargins(self):
+ """Test PlotWidget's getAxesMargins and setAxesMargins"""
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ margins = self.plot.getAxesMargins()
+ self.assertEqual(margins, (.15, .1, .1, .15))
+
+ for margins in ((0., 0., 0., 0.), (.15, .1, .1, .15)):
+ with self.subTest(margins=margins):
+ self.plot.setAxesMargins(*margins)
+ self.qapp.processEvents()
+ self.assertEqual(self.plot.getAxesMargins(), margins)
+
def testBoundingRectItem(self):
item = BoundingRect()
item.setBounds((-1000, 1000, -2000, 2000))
@@ -1752,80 +1828,33 @@ class TestPlotMarkerLog(PlotWidgetTestCase):
self.plot.resetZoom()
-class TestPlotItemLog(PlotWidgetTestCase):
- """Basic tests for items with log scale axes"""
+class TestPlotWidgetSwitchBackend(PlotWidgetTestCase):
+ """Test [get|set]Backend to switch backend"""
- # Polygon coordinates and color
- polygons = [ # legend, x coords, y coords, color
- ('triangle', numpy.array((10, 30, 50)),
- numpy.array((55, 70, 55)), 'red'),
- ('square', numpy.array((10, 10, 50, 50)),
- numpy.array((10, 50, 50, 10)), 'green'),
- ('star', numpy.array((60, 70, 80, 60, 80)),
- numpy.array((25, 50, 25, 40, 40)), 'blue'),
- ]
-
- # Rectangle coordinantes and color
- rectangles = [ # legend, x coords, y coords, color
- ('square 1', numpy.array((1., 10.)),
- numpy.array((1., 10.)), 'red'),
- ('square 2', numpy.array((10., 20.)),
- numpy.array((10., 20.)), 'green'),
- ('square 3', numpy.array((20., 30.)),
- numpy.array((20., 30.)), 'blue'),
- ('rect 1', numpy.array((1., 30.)),
- numpy.array((35., 40.)), 'black'),
- ('line h', numpy.array((1., 30.)),
- numpy.array((45., 45.)), 'darkRed'),
- ]
-
- def setUp(self):
- super(TestPlotItemLog, self).setUp()
+ def testSwitchBackend(self):
+ """Test switching a plot with a few items"""
+ backends = {'none': 'BackendBase', 'mpl': 'BackendMatplotlibQt'}
+ if test_options.WITH_GL_TEST:
+ backends['gl'] = 'BackendOpenGL'
- self.plot.getYAxis().setLabel('Rows')
- self.plot.getXAxis().setLabel('Columns')
- self.plot.getXAxis().setAutoScale(False)
- self.plot.getYAxis().setAutoScale(False)
- self.plot.setKeepDataAspectRatio(False)
- self.plot.setLimits(1., 100., 1., 100.)
- self.plot.getXAxis()._setLogarithmic(True)
- self.plot.getYAxis()._setLogarithmic(True)
-
- def testPlotItemPolygonLogFill(self):
- self.plot.setGraphTitle('Item Fill Log')
-
- for legend, xList, yList, color in self.polygons:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False,
- shape="polygon", fill=True, color=color)
- self.plot.resetZoom()
-
- def testPlotItemPolygonLogNoFill(self):
- self.plot.setGraphTitle('Item No Fill Log')
-
- for legend, xList, yList, color in self.polygons:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False,
- shape="polygon", fill=False, color=color)
- self.plot.resetZoom()
-
- def testPlotItemRectangleLogFill(self):
- self.plot.setGraphTitle('Rectangle Fill Log')
-
- for legend, xList, yList, color in self.rectangles:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False,
- shape="rectangle", fill=True, color=color)
+ self.plot.addImage(numpy.arange(100).reshape(10, 10))
+ self.plot.addCurve((-3, -2, -1), (1, 2, 3))
self.plot.resetZoom()
+ xlimits = self.plot.getXAxis().getLimits()
+ ylimits = self.plot.getYAxis().getLimits()
+ items = self.plot.getItems()
+ self.assertEqual(len(items), 2)
- def testPlotItemRectangleLogNoFill(self):
- self.plot.setGraphTitle('Rectangle No Fill Log')
+ for backend, className in backends.items():
+ with self.subTest(backend=backend):
+ self.plot.setBackend(backend)
+ self.plot.replot()
- for legend, xList, yList, color in self.rectangles:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False,
- shape="rectangle", fill=False, color=color)
- self.plot.resetZoom()
+ retrievedBackend = self.plot.getBackend()
+ self.assertEqual(type(retrievedBackend).__name__, className)
+ self.assertEqual(self.plot.getXAxis().getLimits(), xlimits)
+ self.assertEqual(self.plot.getYAxis().getLimits(), ylimits)
+ self.assertEqual(self.plot.getItems(), items)
def suite():
@@ -1841,8 +1870,7 @@ def suite():
TestPlotEmptyLog,
TestPlotCurveLog,
TestPlotImageLog,
- TestPlotMarkerLog,
- TestPlotItemLog)
+ TestPlotMarkerLog)
test_suite = unittest.TestSuite()
@@ -1859,6 +1887,9 @@ def suite():
for testClass in testClasses:
test_suite.addTest(parameterize(testClass, backend='gl'))
+ test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
+ TestPlotWidgetSwitchBackend))
+
return test_suite
diff --git a/silx/gui/plot/test/testPlotWindow.py b/silx/gui/plot/test/testPlotWindow.py
index 8e7b35c..e12b756 100644
--- a/silx/gui/plot/test/testPlotWindow.py
+++ b/silx/gui/plot/test/testPlotWindow.py
@@ -33,12 +33,12 @@ import unittest
import numpy
from silx.gui.utils.testutils import TestCaseQt, getQToolButtonFromAction
+from silx.test.utils import test_options
from silx.gui import qt
from silx.gui.plot import PlotWindow
from silx.gui.colors import Colormap
-
class TestPlotWindow(TestCaseQt):
"""Base class for tests of PlotWindow."""
@@ -155,6 +155,25 @@ class TestPlotWindow(TestCaseQt):
self.assertEqual(self._count, 1)
del self._count
+ @unittest.skipUnless(test_options.WITH_GL_TEST,
+ test_options.WITH_QT_TEST_REASON)
+ def testSwitchBackend(self):
+ """Test switching an empty plot"""
+ self.plot.resetZoom()
+ xlimits = self.plot.getXAxis().getLimits()
+ ylimits = self.plot.getYAxis().getLimits()
+ isKeepAspectRatio = self.plot.isKeepDataAspectRatio()
+
+ for backend in ('gl', 'mpl'):
+ with self.subTest():
+ self.plot.setBackend(backend)
+ self.plot.replot()
+ self.assertEqual(self.plot.getXAxis().getLimits(), xlimits)
+ self.assertEqual(self.plot.getYAxis().getLimits(), ylimits)
+ self.assertEqual(
+ self.plot.isKeepDataAspectRatio(), isKeepAspectRatio)
+
+
def suite():
test_suite = unittest.TestSuite()
test_suite.addTest(
diff --git a/silx/gui/plot/test/testRoiStatsWidget.py b/silx/gui/plot/test/testRoiStatsWidget.py
new file mode 100644
index 0000000..378d499
--- /dev/null
+++ b/silx/gui/plot/test/testRoiStatsWidget.py
@@ -0,0 +1,290 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for ROIStatsWidget"""
+
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+from silx.gui.plot import PlotWindow
+from silx.gui.plot.stats.stats import Stats
+from silx.gui.plot.ROIStatsWidget import ROIStatsWidget
+from silx.gui.plot.CurvesROIWidget import ROI
+from silx.gui.plot.items.roi import RectangleROI, PolygonROI
+from silx.gui.plot.StatsWidget import UpdateMode
+import unittest
+import numpy
+
+
+
+class _TestRoiStatsBase(TestCaseQt):
+ """Base class for several unittest relative to ROIStatsWidget"""
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ # define plot
+ self.plot = PlotWindow()
+ self.plot.addImage(numpy.arange(10000).reshape(100, 100),
+ legend='img1')
+ self.img_item = self.plot.getImage('img1')
+ self.plot.addCurve(x=numpy.linspace(0, 10, 56), y=numpy.arange(56),
+ legend='curve1')
+ self.curve_item = self.plot.getCurve('curve1')
+ self.plot.addHistogram(edges=numpy.linspace(0, 10, 56),
+ histogram=numpy.arange(56), legend='histo1')
+ self.histogram_item = self.plot.getHistogram(legend='histo1')
+ self.plot.addScatter(x=numpy.linspace(0, 10, 56),
+ y=numpy.linspace(0, 10, 56),
+ value=numpy.arange(56),
+ legend='scatter1')
+ self.scatter_item = self.plot.getScatter(legend='scatter1')
+
+ # stats widget
+ self.statsWidget = ROIStatsWidget(plot=self.plot)
+
+ # define stats
+ stats = [
+ ('sum', numpy.sum),
+ ('mean', numpy.mean),
+ ]
+ self.statsWidget.setStats(stats=stats)
+
+ # define rois
+ self.roi1D = ROI(name='range1', fromdata=0, todata=4, type_='energy')
+ self.rectangle_roi = RectangleROI()
+ self.rectangle_roi.setGeometry(origin=(0, 0), size=(20, 20))
+ self.rectangle_roi.setName('Initial ROI')
+ self.polygon_roi = PolygonROI()
+ points = numpy.array([[0, 5], [5, 0], [10, 5], [5, 10]])
+ self.polygon_roi.setPoints(points)
+
+ def statsTable(self):
+ return self.statsWidget._statsROITable
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self.statsWidget.close()
+ self.statsWidget = None
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self.plot.close()
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+
+class TestRoiStatsCouple(_TestRoiStatsBase):
+ """
+ Test different possible couple (roi, plotItem).
+ Check that:
+
+ * computation is correct if couple is valid
+ * raise an error if couple is invalid
+ """
+ def testROICurve(self):
+ """
+ Test that the couple (ROI, curveItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.curve_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '253')
+ self.assertEqual(tableItems['mean'].text(), '11.0')
+
+ def testRectangleImage(self):
+ """
+ Test that the couple (RectangleROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+ assert item is not None
+ self.plot.addImage(numpy.ones(10000).reshape(100, 100),
+ legend='img1')
+ self.qapp.processEvents()
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), str(float(21*21)))
+ self.assertEqual(tableItems['mean'].text(), '1.0')
+
+ def testPolygonImage(self):
+ """
+ Test that the couple (PolygonROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.polygon_roi,
+ plotItem=self.img_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '22750')
+ self.assertEqual(tableItems['mean'].text(), '455.0')
+
+ def testROIImage(self):
+ """
+ Test that the couple (ROI, imageItem) is raising an error
+ """
+ with self.assertRaises(TypeError):
+ self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.img_item)
+
+ def testRectangleCurve(self):
+ """
+ Test that the couple (rectangleROI, curveItem) is raising an error
+ """
+ with self.assertRaises(TypeError):
+ self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.curve_item)
+
+ def testROIHistogram(self):
+ """
+ Test that the couple (PolygonROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.histogram_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '253')
+ self.assertEqual(tableItems['mean'].text(), '11.0')
+
+ def testROIScatter(self):
+ """
+ Test that the couple (PolygonROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.scatter_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '253')
+ self.assertEqual(tableItems['mean'].text(), '11.0')
+
+
+class TestRoiStatsAddRemoveItem(_TestRoiStatsBase):
+ """Test adding and removing (roi, plotItem) items"""
+ def testAddRemoveItems(self):
+ item1 = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.scatter_item)
+ self.assertTrue(item1 is not None)
+ self.assertEqual(self.statsTable().rowCount(), 1)
+ item2 = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.histogram_item)
+ self.assertTrue(item2 is not None)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ # try to add twice the same item
+ item3 = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.histogram_item)
+ self.assertTrue(item3 is None)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ item4 = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.curve_item)
+ self.assertTrue(item4 is not None)
+ self.assertEqual(self.statsTable().rowCount(), 3)
+
+ self.statsWidget.removeItem(plotItem=item4._plot_item,
+ roi=item4._roi)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ # try to remove twice the same item
+ self.statsWidget.removeItem(plotItem=item4._plot_item,
+ roi=item4._roi)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ self.statsWidget.removeItem(plotItem=item2._plot_item,
+ roi=item2._roi)
+ self.statsWidget.removeItem(plotItem=item1._plot_item,
+ roi=item1._roi)
+ self.assertEqual(self.statsTable().rowCount(), 0)
+
+
+class TestRoiStatsRoiUpdate(_TestRoiStatsBase):
+ """Test that the stats will be updated if the roi is updated"""
+ def testChangeRoi(self):
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '445410')
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+
+ # update roi
+ self.rectangle_roi.setOrigin(position=(10, 10))
+ self.assertNotEqual(tableItems['sum'].text(), '445410')
+ self.assertNotEqual(tableItems['mean'].text(), '1010.0')
+
+ def testUpdateModeScenario(self):
+ """Test update according to a simple scenario"""
+ self.statsWidget._setUpdateMode(UpdateMode.AUTO)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '445410')
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+ self.statsWidget._setUpdateMode(UpdateMode.MANUAL)
+ self.rectangle_roi.setOrigin(position=(10, 10))
+ self.qapp.processEvents()
+ self.assertNotEqual(tableItems['sum'].text(), '445410')
+ self.assertNotEqual(tableItems['mean'].text(), '1010.0')
+ self.statsWidget._updateAllStats(is_request=True)
+ self.assertNotEqual(tableItems['sum'].text(), '445410')
+ self.assertNotEqual(tableItems['mean'].text(), '1010.0')
+
+
+class TestRoiStatsPlotItemUpdate(_TestRoiStatsBase):
+ """Test that the stats will be updated if the plot item is updated"""
+ def testChangeImage(self):
+ self.statsWidget._setUpdateMode(UpdateMode.AUTO)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+
+ # update plot
+ self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100),
+ legend='img1')
+ self.assertNotEqual(tableItems['mean'].text(), '1059.5')
+
+ def testUpdateModeScenario(self):
+ """Test update according to a simple scenario"""
+ self.statsWidget._setUpdateMode(UpdateMode.MANUAL)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+ self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100),
+ legend='img1')
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+ self.statsWidget._updateAllStats(is_request=True)
+ self.assertEqual(tableItems['mean'].text(), '1110.0')
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ for TestClass in (TestRoiStatsCouple, TestRoiStatsRoiUpdate,
+ TestRoiStatsPlotItemUpdate):
+ test_suite.addTest(
+ unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testScatterMaskToolsWidget.py b/silx/gui/plot/test/testScatterMaskToolsWidget.py
index 171ec42..800f30e 100644
--- a/silx/gui/plot/test/testScatterMaskToolsWidget.py
+++ b/silx/gui/plot/test/testScatterMaskToolsWidget.py
@@ -86,10 +86,16 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.mouseMove(plot, pos=(0, 0))
self.mouseMove(plot, pos=pos0)
- self.mouseClick(plot, qt.Qt.LeftButton, pos=pos0)
- self.mouseMove(plot, pos=(0, 0))
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos0)
+ self.qapp.processEvents()
+
+ self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2))
self.mouseMove(plot, pos=pos1)
- self.mouseClick(plot, qt.Qt.LeftButton, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=(0, 0))
def _drawPolygon(self):
"""Draw a star polygon in the plot"""
@@ -108,7 +114,9 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
for pos in star:
self.mouseMove(plot, pos=pos)
self.qapp.processEvents()
- self.mouseClick(plot, qt.Qt.LeftButton, pos=pos)
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
self.qapp.processEvents()
def _drawPencil(self):
diff --git a/silx/gui/plot/test/testStackView.py b/silx/gui/plot/test/testStackView.py
index 80c85d6..7605bbc 100644
--- a/silx/gui/plot/test/testStackView.py
+++ b/silx/gui/plot/test/testStackView.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -60,6 +60,19 @@ class TestStackView(TestCaseQt):
del self.stackview
super(TestStackView, self).tearDown()
+ def testScaleColormapRangeToStack(self):
+ """Test scaleColormapRangeToStack"""
+ self.stackview.setStack(self.mystack)
+ self.stackview.setColormap("viridis")
+ colormap = self.stackview.getColormap()
+
+ # Colormap autoscale to image
+ self.assertEqual(colormap.getVRange(), (None, None))
+ self.stackview.scaleColormapRangeToStack()
+
+ # Colormap range set according to stack range
+ self.assertEqual(colormap.getVRange(), (self.mystack.min(), self.mystack.max()))
+
def testSetStack(self):
self.stackview.setStack(self.mystack)
self.stackview.setColormap("viridis", autoscale=True)
diff --git a/silx/gui/plot/test/testStats.py b/silx/gui/plot/test/testStats.py
index 8db8cc9..d5046ba 100644
--- a/silx/gui/plot/test/testStats.py
+++ b/silx/gui/plot/test/testStats.py
@@ -35,6 +35,11 @@ from silx.gui.plot import StatsWidget
from silx.gui.plot.stats import statshandler
from silx.gui.utils.testutils import TestCaseQt, SignalListener
from silx.gui.plot import Plot1D, Plot2D
+from silx.gui.plot3d.SceneWidget import SceneWidget
+from silx.gui.plot.items.roi import RectangleROI, PolygonROI
+from silx.gui.plot.tools.roi import RegionOfInterestManager
+from silx.gui.plot.stats.stats import Stats
+from silx.gui.plot.CurvesROIWidget import ROI
from silx.utils.testutils import ParametricTestCase
import unittest
import logging
@@ -43,12 +48,9 @@ import numpy
_logger = logging.getLogger(__name__)
-class TestStats(TestCaseQt):
- """
- Test :class:`BaseClass` class and inheriting classes
- """
+class TestStatsBase(object):
+ """Base class for stats TestCase"""
def setUp(self):
- TestCaseQt.setUp(self)
self.createCurveContext()
self.createImageContext()
self.createScatterContext()
@@ -63,7 +65,6 @@ class TestStats(TestCaseQt):
self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose)
self.scatterPlot.close()
del self.scatterPlot
- TestCaseQt.tearDown(self)
def createCurveContext(self):
self.plot1d = Plot1D()
@@ -74,12 +75,13 @@ class TestStats(TestCaseQt):
self.curveContext = stats._CurveContext(
item=self.plot1d.getCurve('curve0'),
plot=self.plot1d,
- onlimits=False)
+ onlimits=False,
+ roi=None)
def createScatterContext(self):
self.scatterPlot = Plot2D()
lgd = 'scatter plot'
- self.xScatterData = numpy.array([0, 1, 2, 20, 50, 60, 36])
+ self.xScatterData = numpy.array([0, 2, 3, 20, 50, 60, 36])
self.yScatterData = numpy.array([2, 3, 4, 26, 69, 6, 18])
self.valuesScatterData = numpy.array([5, 6, 7, 10, 90, 20, 5])
self.scatterPlot.addScatter(self.xScatterData, self.yScatterData,
@@ -87,7 +89,8 @@ class TestStats(TestCaseQt):
self.scatterContext = stats._ScatterContext(
item=self.scatterPlot.getScatter(lgd),
plot=self.scatterPlot,
- onlimits=False
+ onlimits=False,
+ roi=None
)
def createImageContext(self):
@@ -99,7 +102,8 @@ class TestStats(TestCaseQt):
self.imageContext = stats._ImageContext(
item=self.plot2d.getImage(self._imgLgd),
plot=self.plot2d,
- onlimits=False
+ onlimits=False,
+ roi=None
)
def getBasicStats(self):
@@ -113,6 +117,19 @@ class TestStats(TestCaseQt):
'com': stats.StatCOM()
}
+
+class TestStats(TestStatsBase, TestCaseQt):
+ """
+ Test :class:`BaseClass` class and inheriting classes
+ """
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ TestStatsBase.setUp(self)
+
+ def tearDown(self):
+ TestStatsBase.tearDown(self)
+ TestCaseQt.tearDown(self)
+
def testBasicStatsCurve(self):
"""Test result for simple stats on a curve"""
_stats = self.getBasicStats()
@@ -155,7 +172,8 @@ class TestStats(TestCaseQt):
image2Context = stats._ImageContext(
item=self.plot2d.getImage(self._imgLgd),
plot=self.plot2d,
- onlimits=False
+ onlimits=False,
+ roi=None,
)
_stats = self.getBasicStats()
self.assertEqual(_stats['min'].calculate(image2Context), 0)
@@ -225,21 +243,24 @@ class TestStats(TestCaseQt):
curveContextOnLimits = stats._CurveContext(
item=self.plot1d.getCurve('curve0'),
plot=self.plot1d,
- onlimits=True)
+ onlimits=True,
+ roi=None)
self.assertEqual(stat.calculate(curveContextOnLimits), 2)
self.plot2d.getXAxis().setLimitsConstraints(minPos=32)
imageContextOnLimits = stats._ImageContext(
item=self.plot2d.getImage('test image'),
plot=self.plot2d,
- onlimits=True)
+ onlimits=True,
+ roi=None)
self.assertEqual(stat.calculate(imageContextOnLimits), 32)
self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40)
scatterContextOnLimits = stats._ScatterContext(
item=self.scatterPlot.getScatter('scatter plot'),
plot=self.scatterPlot,
- onlimits=True)
+ onlimits=True,
+ roi=None)
self.assertEqual(stat.calculate(scatterContextOnLimits), 20)
@@ -255,7 +276,8 @@ class TestStatsFormatter(TestCaseQt):
self.curveContext = stats._CurveContext(
item=self.plot1d.getCurve('curve0'),
plot=self.plot1d,
- onlimits=False)
+ onlimits=False,
+ roi=None)
self.stat = stats.StatMin()
@@ -295,6 +317,7 @@ class TestStatsHandler(TestCaseQt):
self.stat = stats.StatMin()
def tearDown(self):
+ Stats._getContext.cache_clear()
self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
self.plot1d.close()
self.plot1d = None
@@ -391,6 +414,7 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase):
self.statsTable.setStats(mystats)
def tearDown(self):
+ Stats._getContext.cache_clear()
self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
self.plot.close()
self.statsTable = None
@@ -493,7 +517,6 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase):
self.qapp.processEvents()
tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
curve0_min = tableItems['min'].text()
- print(curve0_min)
self.assertTrue(float(curve0_min) == -1.)
self.plot.getCurve('curve0').setData(x=range(4), y=range(1, 5))
@@ -581,6 +604,7 @@ class TestStatsWidgetWithImages(TestCaseQt):
self.widget.setStats(mystats)
def tearDown(self):
+ Stats._getContext.cache_clear()
self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
self.plot.close()
self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
@@ -641,6 +665,7 @@ class TestStatsWidgetWithScatters(TestCaseQt):
self.widget.setStats(mystats)
def tearDown(self):
+ Stats._getContext.cache_clear()
self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose)
self.scatterPlot.close()
self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
@@ -694,6 +719,7 @@ class TestLineWidget(TestCaseQt):
stats=mystats)
def tearDown(self):
+ Stats._getContext.cache_clear()
self.qapp.processEvents()
self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
self.plot.close()
@@ -806,12 +832,223 @@ class TestUpdateModeWidget(TestCaseQt):
self.assertEqual(manualUpdateListener.callCount(), 2)
+class TestStatsROI(TestStatsBase, TestCaseQt):
+ """
+ Test stats based on ROI
+ """
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.createRois()
+ TestStatsBase.setUp(self)
+ self.createHistogramContext()
+
+ self.roiManager = RegionOfInterestManager(self.plot2d)
+ self.roiManager.addRoi(self._2Droi_rect)
+ self.roiManager.addRoi(self._2Droi_poly)
+
+ def tearDown(self):
+ self.roiManager.clear()
+ self.roiManager = None
+ self._1Droi = None
+ self._2Droi_rect = None
+ self._2Droi_poly = None
+ self.plotHisto.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plotHisto.close()
+ self.plotHisto = None
+ TestStatsBase.tearDown(self)
+ TestCaseQt.tearDown(self)
+
+ def createRois(self):
+ self._1Droi = ROI(name='my1DRoi', fromdata=2.0, todata=5.0)
+ self._2Droi_rect = RectangleROI()
+ self._2Droi_rect.setGeometry(size=(10, 10), origin=(10, 0))
+ self._2Droi_poly = PolygonROI()
+ points = numpy.array(((0, 20), (0, 0), (10, 0)))
+ self._2Droi_poly.setPoints(points=points)
+
+ def createCurveContext(self):
+ TestStatsBase.createCurveContext(self)
+ self.curveContext = stats._CurveContext(
+ item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=self._1Droi)
+
+ def createHistogramContext(self):
+ self.plotHisto = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plotHisto.addHistogram(x, y, legend='histo0')
+
+ self.histoContext = stats._HistogramContext(
+ item=self.plotHisto.getHistogram('histo0'),
+ plot=self.plotHisto,
+ onlimits=False,
+ roi=self._1Droi)
+
+ def createScatterContext(self):
+ TestStatsBase.createScatterContext(self)
+ self.scatterContext = stats._ScatterContext(
+ item=self.scatterPlot.getScatter('scatter plot'),
+ plot=self.scatterPlot,
+ onlimits=False,
+ roi=self._1Droi
+ )
+
+ def createImageContext(self):
+ TestStatsBase.createImageContext(self)
+
+ self.imageContext = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=self._2Droi_rect
+ )
+
+ self.imageContext_2 = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=self._2Droi_poly
+ )
+
+ def testErrors(self):
+ # test if onlimits is True and give also a roi
+ with self.assertRaises(ValueError):
+ stats._CurveContext(item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=True,
+ roi=self._1Droi)
+
+ # test if is a curve context and give an invalid 2D roi
+ with self.assertRaises(TypeError):
+ stats._CurveContext(item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=self._2Droi_rect)
+
+ def testBasicStatsCurve(self):
+ """Test result for simple stats on a curve"""
+ _stats = self.getBasicStats()
+ xData = yData = numpy.array(range(0, 10))
+ self.assertEqual(_stats['min'].calculate(self.curveContext), 2)
+ self.assertEqual(_stats['max'].calculate(self.curveContext), 5)
+ self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (2,))
+ self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (5,))
+ self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData[2:6]))
+ self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData[2:6]))
+ com = numpy.sum(xData[2:6] * yData[2:6]) / numpy.sum(yData[2:6])
+ self.assertEqual(_stats['com'].calculate(self.curveContext), com)
+
+ def testBasicStatsImageRectRoi(self):
+ """Test result for simple stats on an image"""
+ self.assertEqual(self.imageContext.values.compressed().size, 121)
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.imageContext), 10)
+ self.assertEqual(_stats['max'].calculate(self.imageContext), 1300)
+ self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (10, 0))
+ self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (20.0, 10.0))
+ self.assertAlmostEqual(_stats['std'].calculate(self.imageContext),
+ numpy.std(self.imageData[0:11, 10:21]))
+ self.assertAlmostEqual(_stats['mean'].calculate(self.imageContext),
+ numpy.mean(self.imageData[0:11, 10:21]))
+
+ compressed_values = self.imageContext.values.compressed()
+ compressed_values = compressed_values.reshape(11, 11)
+ yData = numpy.sum(compressed_values.astype(numpy.float64), axis=1)
+ xData = numpy.sum(compressed_values.astype(numpy.float64), axis=0)
+
+ dataYRange = range(11)
+ dataXRange = range(10, 21)
+
+ ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData)
+ xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData)
+ self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom))
+
+ def testBasicStatsImagePolyRoi(self):
+ """Test a simple rectangle ROI"""
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.imageContext_2), 0)
+ self.assertEqual(_stats['max'].calculate(self.imageContext_2), 2432)
+ self.assertEqual(_stats['minCoords'].calculate(self.imageContext_2), (0.0, 0.0))
+ # not 0.0, 19.0 because not fully in. Should all pixel have a weight,
+ # on to manage them in stats. For now 0 if the center is not in, else 1
+ self.assertEqual(_stats['maxCoords'].calculate(self.imageContext_2), (0.0, 19.0))
+
+ def testBasicStatsScatter(self):
+ self.assertEqual(self.scatterContext.values.compressed().size, 2)
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.scatterContext), 6)
+ self.assertEqual(_stats['max'].calculate(self.scatterContext), 7)
+ self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (2, 3))
+ self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (3, 4))
+ self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std([6, 7]))
+ self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean([6, 7]))
+
+ def testBasicHistogram(self):
+ _stats = self.getBasicStats()
+ xData = yData = numpy.array(range(2, 6))
+ self.assertEqual(_stats['min'].calculate(self.histoContext), 2)
+ self.assertEqual(_stats['max'].calculate(self.histoContext), 5)
+ self.assertEqual(_stats['minCoords'].calculate(self.histoContext), (2,))
+ self.assertEqual(_stats['maxCoords'].calculate(self.histoContext), (5,))
+ self.assertEqual(_stats['std'].calculate(self.histoContext), numpy.std(yData))
+ self.assertEqual(_stats['mean'].calculate(self.histoContext), numpy.mean(yData))
+ com = numpy.sum(xData * yData) / numpy.sum(yData)
+ self.assertEqual(_stats['com'].calculate(self.histoContext), com)
+
+
+class TestAdvancedROIImageContext(TestCaseQt):
+ """Test stats result on an image context with different scale and
+ origins"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.data_dims = (100, 100)
+ self.data = numpy.random.rand(*self.data_dims)
+ self.plot = Plot2D()
+
+ def test(self):
+ """Test stats result on an image context with different scale and
+ origins"""
+ roi_origins = [(0, 0), (2, 10), (14, 20)]
+ img_origins = [(0, 0), (14, 20), (2, 10)]
+ img_scales = [1.0, 0.5, 2.0]
+ _stats = {'sum': stats.Stat(name='sum', fct=numpy.sum), }
+ for roi_origin in roi_origins:
+ for img_origin in img_origins:
+ for img_scale in img_scales:
+ with self.subTest(roi_origin=roi_origin,
+ img_origin=img_origin,
+ img_scale=img_scale):
+ self.plot.addImage(self.data, legend='img',
+ origin=img_origin,
+ scale=img_scale)
+ roi = RectangleROI()
+ roi.setGeometry(origin=roi_origin, size=(20, 20))
+ context = stats._ImageContext(
+ item=self.plot.getImage('img'),
+ plot=self.plot,
+ onlimits=False,
+ roi=roi)
+ x_start = int((roi_origin[0] - img_origin[0]) / img_scale)
+ x_end = int(x_start + (20 / img_scale)) + 1
+ y_start = int((roi_origin[1] - img_origin[1])/ img_scale)
+ y_end = int(y_start + (20 / img_scale)) + 1
+ x_start = max(x_start, 0)
+ x_end = min(max(x_end, 0), self.data_dims[1])
+ y_start = max(y_start, 0)
+ y_end = min(max(y_end, 0), self.data_dims[0])
+ th_sum = numpy.sum(self.data[y_start:y_end, x_start:x_end])
+ self.assertAlmostEqual(_stats['sum'].calculate(context),
+ th_sum)
+
def suite():
test_suite = unittest.TestSuite()
for TestClass in (TestStats, TestStatsHandler, TestStatsWidgetWithScatters,
TestStatsWidgetWithImages, TestStatsWidgetWithCurves,
- TestStatsFormatter, TestEmptyStatsWidget,
- TestLineWidget, TestUpdateModeWidget):
+ TestStatsFormatter, TestEmptyStatsWidget, TestStatsROI,
+ TestLineWidget, TestUpdateModeWidget, ):
test_suite.addTest(
unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
return test_suite
diff --git a/silx/gui/plot/tools/profile/manager.py b/silx/gui/plot/tools/profile/manager.py
index 4d467f0..757b741 100644
--- a/silx/gui/plot/tools/profile/manager.py
+++ b/silx/gui/plot/tools/profile/manager.py
@@ -76,6 +76,17 @@ class _RunnableComputeProfile(qt.QRunnable):
self._signals.moveToThread(threadPool.thread())
self._item = item
self._roi = roi
+ self._cancelled = False
+
+ def _lazyCancel(self):
+ """Cancel the runner if it is not yet started.
+
+ The threadpool will still execute the runner, but this will process
+ nothing.
+
+ This is only used with Qt<5.9 where QThreadPool.tryTake is not available.
+ """
+ self._cancelled = True
def autoDelete(self):
return False
@@ -106,12 +117,13 @@ class _RunnableComputeProfile(qt.QRunnable):
def run(self):
"""Process the profile computation.
"""
- try:
- profileData = self._roi.computeProfile(self._item)
- except Exception:
- _logger.error("Error while computing profile", exc_info=True)
- else:
- self.resultReady.emit(self._roi, profileData)
+ if not self._cancelled:
+ try:
+ profileData = self._roi.computeProfile(self._item)
+ except Exception:
+ _logger.error("Error while computing profile", exc_info=True)
+ else:
+ self.resultReady.emit(self._roi, profileData)
self.runnerFinished.emit(self)
@@ -815,8 +827,11 @@ class ProfileManager(qt.QObject):
self._pendingRunners.remove(runner)
continue
if runner.getRoi() is profileRoi:
- if threadPool.tryTake(runner):
- self._pendingRunners.remove(runner)
+ if hasattr(threadPool, "tryTake"):
+ if threadPool.tryTake(runner):
+ self._pendingRunners.remove(runner)
+ else: # Support Qt<5.9
+ runner._lazyCancel()
item = self.getPlotItem()
if item is None or not isinstance(item, profileRoi.ITEM_KIND):
diff --git a/silx/gui/plot/tools/profile/rois.py b/silx/gui/plot/tools/profile/rois.py
index b49679c..9e651a7 100644
--- a/silx/gui/plot/tools/profile/rois.py
+++ b/silx/gui/plot/tools/profile/rois.py
@@ -137,11 +137,11 @@ class _ImageProfileArea(items.Shape):
if not isinstance(item, items.ImageBase):
raise TypeError("Unexpected class %s" % type(item))
- if isinstance(item, items.ImageData):
- currentData = item.getData(copy=False)
- elif isinstance(item, items.ImageRgba):
+ if isinstance(item, items.ImageRgba):
rgba = item.getData(copy=False)
currentData = rgba[..., 0]
+ else:
+ currentData = item.getData(copy=False)
roi = self.getParentRoi()
origin = item.getOrigin()
@@ -310,15 +310,15 @@ class _DefaultImageProfileRoiMixIn(core.ProfileRoiMixIn):
method=method)
return coords, profile, profileName, xLabel
- if isinstance(item, items.ImageData):
- currentData = item.getData(copy=False)
- elif isinstance(item, items.ImageRgba):
+ if isinstance(item, items.ImageRgba):
rgba = item.getData(copy=False)
is_uint8 = rgba.dtype.type == numpy.uint8
# luminosity
if is_uint8:
- rgba = rgba.astype(numpy.float)
+ rgba = rgba.astype(numpy.float64)
currentData = 0.21 * rgba[..., 0] + 0.72 * rgba[..., 1] + 0.07 * rgba[..., 2]
+ else:
+ currentData = item.getData(copy=False)
yLabel = "%s" % str(method).capitalize()
coords, profile, title, xLabel = createProfile2(currentData)
diff --git a/silx/gui/plot/tools/roi.py b/silx/gui/plot/tools/roi.py
index 431ecb2..4e2d6db 100644
--- a/silx/gui/plot/tools/roi.py
+++ b/silx/gui/plot/tools/roi.py
@@ -34,10 +34,13 @@ import enum
import logging
import time
import weakref
+import functools
import numpy
from ... import qt, icons
+from ...utils import blockSignals
+from ...utils import LockReentrant
from .. import PlotWidget
from ..items import roi as roi_items
@@ -163,6 +166,155 @@ class CreateRoiModeAction(qt.QAction):
pass
+class RoiModeSelector(qt.QWidget):
+ def __init__(self, parent=None):
+ super(RoiModeSelector, self).__init__(parent=parent)
+ self.__roi = None
+ self.__reentrant = LockReentrant()
+
+ layout = qt.QHBoxLayout(self)
+ if isinstance(parent, qt.QMenu):
+ margins = layout.contentsMargins()
+ layout.setContentsMargins(margins.left(), 0, margins.right(), 0)
+ else:
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ self._label = qt.QLabel(self)
+ self._label.setText("Mode:")
+ self._label.setToolTip("Select a specific interaction to edit the ROI")
+ self._combo = qt.QComboBox(self)
+ self._combo.currentIndexChanged.connect(self._modeSelected)
+ layout.addWidget(self._label)
+ layout.addWidget(self._combo)
+ self._updateAvailableModes()
+
+ def getRoi(self):
+ """Returns the edited ROI.
+
+ :rtype: roi_items.RegionOfInterest
+ """
+ return self.__roi
+
+ def setRoi(self, roi):
+ """Returns the edited ROI.
+
+ :rtype: roi_items.RegionOfInterest
+ """
+ if self.__roi is roi:
+ return
+ if not isinstance(roi, roi_items.InteractionModeMixIn):
+ self.__roi = None
+ self._updateAvailableModes()
+ return
+
+ if self.__roi is not None:
+ self.__roi.sigInteractionModeChanged.disconnect(self._modeChanged)
+ self.__roi = roi
+ if self.__roi is not None:
+ self.__roi.sigInteractionModeChanged.connect(self._modeChanged)
+ self._updateAvailableModes()
+
+ def isEmpty(self):
+ return not self._label.isVisibleTo(self)
+
+ def _updateAvailableModes(self):
+ roi = self.getRoi()
+ if isinstance(roi, roi_items.InteractionModeMixIn):
+ modes = roi.availableInteractionModes()
+ else:
+ modes = []
+ if len(modes) <= 1:
+ self._label.setVisible(False)
+ self._combo.setVisible(False)
+ else:
+ self._label.setVisible(True)
+ self._combo.setVisible(True)
+ with blockSignals(self._combo):
+ self._combo.clear()
+ for im, m in enumerate(modes):
+ self._combo.addItem(m.label, m)
+ self._combo.setItemData(im, m.description, qt.Qt.ToolTipRole)
+ mode = roi.getInteractionMode()
+ self._modeChanged(mode)
+ index = modes.index(mode)
+ self._combo.setCurrentIndex(index)
+
+ def _modeChanged(self, mode):
+ """Triggered when the ROI interaction mode was changed externally"""
+ if self.__reentrant.locked():
+ # This event was initialised by the widget
+ return
+ roi = self.__roi
+ modes = roi.availableInteractionModes()
+ index = modes.index(mode)
+ with blockSignals(self._combo):
+ self._combo.setCurrentIndex(index)
+
+ def _modeSelected(self):
+ """Triggered when the ROI interaction mode was selected in the widget"""
+ index = self._combo.currentIndex()
+ if index == -1:
+ return
+ roi = self.getRoi()
+ if roi is not None:
+ mode = self._combo.itemData(index, qt.Qt.UserRole)
+ with self.__reentrant:
+ roi.setInteractionMode(mode)
+
+
+class RoiModeSelectorAction(qt.QWidgetAction):
+ """Display the selected mode of a ROI and allow to change it"""
+
+ def __init__(self, parent=None):
+ super(RoiModeSelectorAction, self).__init__(parent)
+ self.__roiManager = None
+
+ def createWidget(self, parent):
+ """Inherit the method to create a new widget"""
+ widget = RoiModeSelector(parent)
+ manager = self.__roiManager
+ if manager is not None:
+ roi = manager.getCurrentRoi()
+ widget.setRoi(roi)
+ self.setVisible(not widget.isEmpty())
+ return widget
+
+ def deleteWidget(self, widget):
+ """Inherit the method to delete a widget"""
+ widget.setRoi(None)
+ return qt.QWidgetAction.deleteWidget(self, widget)
+
+ def setRoiManager(self, roiManager):
+ """
+ Connect this action to a ROI manager.
+
+ :param RegionOfInterestManager roiManager: A ROI manager
+ """
+ if self.__roiManager is roiManager:
+ return
+ if self.__roiManager is not None:
+ self.__roiManager.sigCurrentRoiChanged.disconnect(self.__currentRoiChanged)
+ self.__roiManager = roiManager
+ if self.__roiManager is not None:
+ self.__roiManager.sigCurrentRoiChanged.connect(self.__currentRoiChanged)
+ self.__currentRoiChanged(roiManager.getCurrentRoi())
+
+ def __currentRoiChanged(self, roi):
+ """Handle changes of the selected ROI"""
+ self.setRoi(roi)
+
+ def setRoi(self, roi):
+ """Set a profile ROI to edit.
+
+ :param ProfileRoiMixIn roi: A profile ROI
+ """
+ widget = None
+ for widget in self.createdWidgets():
+ widget.setRoi(roi)
+ if widget is not None:
+ self.setVisible(not widget.isEmpty())
+
+
class RegionOfInterestManager(qt.QObject):
"""Class handling ROI interaction on a PlotWidget.
@@ -257,6 +409,8 @@ class RegionOfInterestManager(qt.QObject):
parent.sigItemRemoved.connect(self._itemRemoved)
+ parent._sigDefaultContextMenu.connect(self._feedContextMenu)
+
@classmethod
def getSupportedRoiClasses(cls):
"""Returns the default available ROI classes
@@ -400,25 +554,87 @@ class RegionOfInterestManager(qt.QObject):
def _plotSignals(self, event):
"""Handle mouse interaction for ROI addition"""
- if event['event'] in ('markerClicked', 'markerMoving'):
+ clicked = False
+ roi = None
+ if event["event"] in ("markerClicked", "markerMoving"):
plot = self.parent()
- legend = event['label']
+ legend = event["label"]
marker = plot._getMarker(legend=legend)
roi = self.__getRoiFromMarker(marker)
- if roi is not None and roi.isSelectable():
- self.setCurrentRoi(roi)
- else:
- self.setCurrentRoi(None)
- elif event['event'] == 'mouseClicked' and event['button'] == 'left':
+ elif event["event"] == "mouseClicked" and event["button"] == "left":
# Marker click is only for dnd
# This also can click on a marker
+ clicked = True
plot = self.parent()
- marker = plot._getMarkerAt(event['xpixel'], event['ypixel'])
+ marker = plot._getMarkerAt(event["xpixel"], event["ypixel"])
roi = self.__getRoiFromMarker(marker)
- if roi is not None and roi.isSelectable():
+ else:
+ return
+
+ if roi not in self._rois:
+ # The ROI is not own by this manager
+ return
+
+ if roi is not None:
+ currentRoi = self.getCurrentRoi()
+ if currentRoi is roi:
+ if clicked:
+ self.__updateMode(roi)
+ elif roi.isSelectable():
self.setCurrentRoi(roi)
+ else:
+ self.setCurrentRoi(None)
+
+ def __updateMode(self, roi):
+ if isinstance(roi, roi_items.InteractionModeMixIn):
+ available = roi.availableInteractionModes()
+ mode = roi.getInteractionMode()
+ imode = available.index(mode)
+ mode = available[(imode + 1) % len(available)]
+ roi.setInteractionMode(mode)
+
+ def _feedContextMenu(self, menu):
+ """Called wen the default plot context menu is about to be displayed"""
+ roi = self.getCurrentRoi()
+ if roi is not None:
+ if roi.isEditable():
+ # Filter by data position
+ # FIXME: It would be better to use GUI coords for it
+ plot = self.parent()
+ pos = plot.getWidgetHandle().mapFromGlobal(qt.QCursor.pos())
+ data = plot.pixelToData(pos.x(), pos.y())
+ if roi.contains(data):
+ if isinstance(roi, roi_items.InteractionModeMixIn):
+ self._contextMenuForInteractionMode(menu, roi)
+
+ removeAction = qt.QAction(menu)
+ removeAction.setText("Remove %s" % roi.getName())
+ callback = functools.partial(self.removeRoi, roi)
+ removeAction.triggered.connect(callback)
+ menu.addAction(removeAction)
+
+ def _contextMenuForInteractionMode(self, menu, roi):
+ availableModes = roi.availableInteractionModes()
+ currentMode = roi.getInteractionMode()
+ submenu = qt.QMenu(menu)
+ modeGroup = qt.QActionGroup(menu)
+ modeGroup.setExclusive(True)
+ for mode in availableModes:
+ action = qt.QAction(menu)
+ action.setText(mode.label)
+ action.setToolTip(mode.description)
+ action.setCheckable(True)
+ if mode is currentMode:
+ action.setChecked(True)
else:
- self.setCurrentRoi(None)
+ callback = functools.partial(roi.setInteractionMode, mode)
+ action.triggered.connect(callback)
+ modeGroup.addAction(action)
+ submenu.addAction(action)
+ action = qt.QAction(menu)
+ action.setMenu(submenu)
+ action.setText("%s interaction mode" % roi.getName())
+ menu.addAction(action)
# RegionOfInterest API
@@ -666,8 +882,9 @@ class RegionOfInterestManager(qt.QObject):
if self._drawnROI is not None:
# Cancel ROI create
- self.removeRoi(self._drawnROI)
+ roi = self._drawnROI
self._drawnROI = None
+ self.removeRoi(roi)
plot = self.parent()
if plot is not None:
diff --git a/silx/gui/plot/tools/test/testROI.py b/silx/gui/plot/tools/test/testROI.py
index 33a0000..8a00073 100644
--- a/silx/gui/plot/tools/test/testROI.py
+++ b/silx/gui/plot/tools/test/testROI.py
@@ -136,6 +136,31 @@ class TestRoiItems(TestCaseQt):
numpy.testing.assert_allclose(item.getCenter(), center)
numpy.testing.assert_allclose(item.getRadius(), newRadius)
+ def testCircle_contains(self):
+ center = numpy.array([2, -1])
+ radius = 1.
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ self.assertTrue(item.contains([1, -1]))
+ self.assertFalse(item.contains([0, 0]))
+ self.assertTrue(item.contains([2, 0]))
+ self.assertFalse(item.contains([3.01, -1]))
+
+ def testEllipse_contains(self):
+ center = numpy.array([-2, 0])
+ item = roi_items.EllipseROI()
+ item.setCenter(center)
+ item.setOrientation(numpy.pi / 4.0)
+ item.setMajorRadius(2)
+ item.setMinorRadius(1)
+ print(item.getMinorRadius(), item.getMajorRadius())
+ self.assertFalse(item.contains([0, 0]))
+ self.assertTrue(item.contains([-1, 1]))
+ self.assertTrue(item.contains([-3, 0]))
+ self.assertTrue(item.contains([-2, 0]))
+ self.assertTrue(item.contains([-2, 1]))
+ self.assertFalse(item.contains([-4, 1]))
+
def testRectangle_isIn(self):
origin = numpy.array([0, 0])
size = numpy.array([10, 20])
@@ -557,8 +582,9 @@ class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase):
mx, my = self.plot.dataToPixel(*center)
self.mouseMove(widget, pos=(mx, my))
self.mousePress(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.mouseMove(widget, pos=(mx, my+25))
self.mouseMove(widget, pos=(mx, my+50))
- self.mouseRelease(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=(mx, my+50))
result = numpy.array(item.getEndPoints())
# x location is still the same
@@ -615,6 +641,45 @@ class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase):
# Clean up
manager.clear()
+ def testArcRoiSwitchMode(self):
+ """Make sure we can switch mode by clicking on the ROI"""
+ xlimit = self.plot.getXAxis().getLimits()
+ ylimit = self.plot.getYAxis().getLimits()
+ points = numpy.array([xlimit, ylimit]).T
+ center = numpy.mean(points, axis=0)
+ size = numpy.abs(points[1] - points[0])
+
+ # Create the line
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.ArcROI()
+ item.setGeometry(center, size[1] / 10, size[1] / 2, 0, 3)
+ item.setEditable(True)
+ item.setSelectable(True)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+
+ # Initial state
+ self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode)
+ self.qWait(500)
+
+ # Click on the center
+ widget = self.plot.getWidgetHandle()
+ mx, my = self.plot.dataToPixel(*center)
+
+ # Select the ROI
+ self.mouseMove(widget, pos=(mx, my))
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.qWait(500)
+ self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode)
+
+ # Change the mode
+ self.mouseMove(widget, pos=(mx, my))
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.qWait(500)
+ self.assertIs(item.getInteractionMode(), roi_items.ArcROI.PolarMode)
+
+ manager.clear()
+ self.qapp.processEvents()
def suite():
diff --git a/silx/gui/plot3d/ScalarFieldView.py b/silx/gui/plot3d/ScalarFieldView.py
index 50cba05..b2bb254 100644
--- a/silx/gui/plot3d/ScalarFieldView.py
+++ b/silx/gui/plot3d/ScalarFieldView.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -239,7 +239,7 @@ class SelectedRegion(object):
def __init__(self, arrayRange, dataBBox,
translation=(0., 0., 0.),
scale=(1., 1., 1.)):
- self._arrayRange = numpy.array(arrayRange, copy=True, dtype=numpy.int)
+ self._arrayRange = numpy.array(arrayRange, copy=True, dtype=numpy.int64)
assert self._arrayRange.shape == (3, 2)
assert numpy.all(self._arrayRange[:, 1] >= self._arrayRange[:, 0])
@@ -1449,7 +1449,7 @@ class ScalarFieldView(Plot3DWindow):
min(self._data.shape[1], max(*yrange))),
(max(0, min(*xrange_)),
min(self._data.shape[2], max(*xrange_))),
- ), dtype=numpy.int)
+ ), dtype=numpy.int64)
# numpy.equal supports None
if not numpy.all(numpy.equal(selectedRange, self._selectedRange)):
diff --git a/silx/gui/plot3d/items/_pick.py b/silx/gui/plot3d/items/_pick.py
index 8494723..0d6a495 100644
--- a/silx/gui/plot3d/items/_pick.py
+++ b/silx/gui/plot3d/items/_pick.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -197,7 +197,7 @@ class PickingResult(_PickingResult):
super(PickingResult, self).__init__(item, indices)
self._objectPositions = numpy.array(
- positions, copy=False, dtype=numpy.float)
+ positions, copy=False, dtype=numpy.float64)
# Store matrices to generate positions on demand
primitive = item._getScenePrimitive()
diff --git a/silx/gui/plot3d/items/core.py b/silx/gui/plot3d/items/core.py
index 1745b2b..ab2ceb6 100644
--- a/silx/gui/plot3d/items/core.py
+++ b/silx/gui/plot3d/items/core.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -400,32 +400,32 @@ class DataItem3D(Item3D):
self._updated(Item3DChangedType.TRANSFORM)
def setRotationCenter(self, x=0., y=0., z=0.):
- """Set the center of rotation of the item.
-
- Position of the rotation center is either a float
- for an absolute position or one of the following
- string to define a position relative to the item's bounding box:
- 'lower', 'center', 'upper'
-
- :param x: rotation center position on the X axis
- :rtype: float or str
- :param y: rotation center position on the Y axis
- :rtype: float or str
- :param z: rotation center position on the Z axis
- :rtype: float or str
- """
- center = []
- for position in (x, y, z):
- if isinstance(position, six.string_types):
- assert position in self._ROTATION_CENTER_TAGS
- else:
- position = float(position)
- center.append(position)
- center = tuple(center)
-
- if center != self._rotationCenter:
- self._rotationCenter = center
- self._updateRotationCenter()
+ """Set the center of rotation of the item.
+
+ Position of the rotation center is either a float
+ for an absolute position or one of the following
+ string to define a position relative to the item's bounding box:
+ 'lower', 'center', 'upper'
+
+ :param x: rotation center position on the X axis
+ :rtype: float or str
+ :param y: rotation center position on the Y axis
+ :rtype: float or str
+ :param z: rotation center position on the Z axis
+ :rtype: float or str
+ """
+ center = []
+ for position in (x, y, z):
+ if isinstance(position, six.string_types):
+ assert position in self._ROTATION_CENTER_TAGS
+ else:
+ position = float(position)
+ center.append(position)
+ center = tuple(center)
+
+ if center != self._rotationCenter:
+ self._rotationCenter = center
+ self._updateRotationCenter()
def getRotationCenter(self):
"""Returns the rotation center set by :meth:`setRotationCenter`.
diff --git a/silx/gui/plot3d/items/mixins.py b/silx/gui/plot3d/items/mixins.py
index 14cafc8..f512365 100644
--- a/silx/gui/plot3d/items/mixins.py
+++ b/silx/gui/plot3d/items/mixins.py
@@ -141,6 +141,7 @@ class ColormapMixIn(_ColormapMixIn):
self.__sceneColormap.norm = colormap.getNormalization()
self.__sceneColormap.gamma = colormap.getGammaNormalizationParameter()
self.__sceneColormap.range_ = colormap.getColormapRange(self)
+ self.__sceneColormap.nancolor = rgba(colormap.getNaNColor())
class ComplexMixIn(_ComplexMixIn):
diff --git a/silx/gui/plot3d/items/volume.py b/silx/gui/plot3d/items/volume.py
index 6c6562f..f80fea2 100644
--- a/silx/gui/plot3d/items/volume.py
+++ b/silx/gui/plot3d/items/volume.py
@@ -444,7 +444,7 @@ class Isosurface(Item3D):
return None # No intersected triangles
intersections = numpy.array(intersections)[numpy.argsort(depths)]
- indices = numpy.transpose(numpy.round(intersections).astype(numpy.int))
+ indices = numpy.transpose(numpy.round(intersections).astype(numpy.int64))
return PickingResult(self, positions=intersections, indices=indices)
diff --git a/silx/gui/plot3d/scene/cutplane.py b/silx/gui/plot3d/scene/cutplane.py
index 81c74c7..88147df 100644
--- a/silx/gui/plot3d/scene/cutplane.py
+++ b/silx/gui/plot3d/scene/cutplane.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -88,7 +88,7 @@ class ColormapMesh3D(Geometry):
float value = texture3D(data, vTexCoords).r;
vec4 color = $colormapCall(value);
- color.a = alpha;
+ color.a *= alpha;
gl_FragColor = $lightingCall(color, vPosition, vNormal);
diff --git a/silx/gui/plot3d/scene/function.py b/silx/gui/plot3d/scene/function.py
index 69a24dd..2deb785 100644
--- a/silx/gui/plot3d/scene/function.py
+++ b/silx/gui/plot3d/scene/function.py
@@ -389,10 +389,13 @@ class Colormap(event.Notifier, ProgramFunction):
uniform float cmap_parameter;
uniform float cmap_min;
uniform float cmap_oneOverRange;
+ uniform vec4 nancolor;
const float oneOverLog10 = 0.43429448190325176;
vec4 colormap(float value) {
+ float data = value; /* Keep original input value for isnan test */
+
if (cmap_normalization == 1) { /* Log10 mapping */
if (value > 0.0) {
value = clamp(cmap_oneOverRange *
@@ -421,7 +424,12 @@ class Colormap(event.Notifier, ProgramFunction):
$discard
- vec4 color = texture2D(cmap_texture, vec2(value, 0.5));
+ vec4 color;
+ if (data != data) { /* isnan alternative for compatibility with GLSL 1.20 */
+ color = nancolor;
+ } else {
+ color = texture2D(cmap_texture, vec2(value, 0.5));
+ }
return color;
}
""")
@@ -458,9 +466,10 @@ class Colormap(event.Notifier, ProgramFunction):
self._gamma = -1.
self._range = 1., 10.
self._displayValuesBelowMin = True
+ self._nancolor = numpy.array((1., 1., 1., 0.), dtype=numpy.float32)
self._texture = None
- self._update_texture = True
+ self._textureToDiscard = None
if colormap is None:
# default colormap
@@ -468,7 +477,7 @@ class Colormap(event.Notifier, ProgramFunction):
colormap[:] = numpy.arange(256,
dtype=numpy.uint8)[:, numpy.newaxis]
- # Set to param values through properties to go through asserts
+ # Set to values through properties to perform asserts and updates
self.colormap = colormap
self.norm = norm
self.gamma = gamma
@@ -491,10 +500,41 @@ class Colormap(event.Notifier, ProgramFunction):
assert colormap.ndim == 2
assert colormap.shape[1] in (3, 4)
self._colormap = colormap
- self._update_texture = True
+
+ if self._texture is not None and self._texture.name is not None:
+ self._textureToDiscard = self._texture
+
+ data = numpy.empty(
+ (16, self._colormap.shape[0], self._colormap.shape[1]),
+ dtype=self._colormap.dtype)
+ data[:] = self._colormap
+
+ format_ = gl.GL_RGBA if data.shape[-1] == 4 else gl.GL_RGB
+
+ self._texture = _glutils.Texture(
+ format_, data, format_,
+ texUnit=self._COLORMAP_TEXTURE_UNIT,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=gl.GL_CLAMP_TO_EDGE)
+
self.notify()
@property
+ def nancolor(self):
+ """RGBA color to use for Not-A-Number values as 4 float in [0., 1.]"""
+ return self._nancolor
+
+ @nancolor.setter
+ def nancolor(self, color):
+ color = numpy.clip(numpy.array(color, dtype=numpy.float32), 0., 1.)
+ assert color.ndim == 1
+ assert len(color) == 4
+ if not numpy.array_equal(self._nancolor, color):
+ self._nancolor = color
+ self.notify()
+
+ @property
def norm(self):
"""Normalization to use for colormap mapping.
@@ -576,9 +616,6 @@ class Colormap(event.Notifier, ProgramFunction):
"""
self.prepareGL2(context) # TODO see how to handle
- if self._texture is None: # No colormap
- return
-
self._texture.bind()
gl.glUniform1i(program.uniforms['cmap_texture'],
@@ -607,23 +644,11 @@ class Colormap(event.Notifier, ProgramFunction):
gl.glUniform1f(program.uniforms['cmap_min'], min_)
gl.glUniform1f(program.uniforms['cmap_oneOverRange'],
(1. / (max_ - min_)) if max_ != min_ else 0.)
+ gl.glUniform4f(program.uniforms['nancolor'], *self._nancolor)
def prepareGL2(self, context):
- if self._texture is None or self._update_texture:
- if self._texture is not None:
- self._texture.discard()
-
- colormap = numpy.empty(
- (16, self._colormap.shape[0], self._colormap.shape[1]),
- dtype=self._colormap.dtype)
- colormap[:] = self._colormap
-
- format_ = gl.GL_RGBA if colormap.shape[-1] == 4 else gl.GL_RGB
-
- self._texture = _glutils.Texture(
- format_, colormap, format_,
- texUnit=self._COLORMAP_TEXTURE_UNIT,
- minFilter=gl.GL_NEAREST,
- magFilter=gl.GL_NEAREST,
- wrap=gl.GL_CLAMP_TO_EDGE)
- self._update_texture = False
+ if self._textureToDiscard is not None:
+ self._textureToDiscard.discard()
+ self._textureToDiscard = None
+
+ self._texture.prepare()
diff --git a/silx/gui/plot3d/scene/primitives.py b/silx/gui/plot3d/scene/primitives.py
index 7db61e8..b4c8e26 100644
--- a/silx/gui/plot3d/scene/primitives.py
+++ b/silx/gui/plot3d/scene/primitives.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -49,7 +49,7 @@ from . import event
from . import core
from . import transform
from . import utils
-from .function import Colormap, Fog
+from .function import Colormap
_logger = logging.getLogger(__name__)
@@ -367,7 +367,7 @@ class Geometry(core.Elem):
min_ = numpy.nanmin(attribute, axis=0)
max_ = numpy.nanmax(attribute, axis=0)
else:
- min_, max_ = numpy.zeros((2, attribute.shape[1]), dtype=numpy.float32)
+ min_, max_ = numpy.zeros((2, attribute.shape[1]), dtype=numpy.float32)
toCopy = min(len(min_), 3-index)
if toCopy != len(min_):
@@ -2077,7 +2077,7 @@ class _Image(Geometry):
self._update_texture = True
# By updating the position rather than always using a unit square
# we benefit from Geometry bounds handling
- self.setAttribute('position', self._UNIT_SQUARE * self._data.shape[:2])
+ self.setAttribute('position', self._UNIT_SQUARE * (self._data.shape[1], self._data.shape[0]))
self.notify()
def getData(self, copy=True):
@@ -2188,7 +2188,7 @@ class _Image(Geometry):
gl.glUniform1f(program.uniforms['alpha'], self._alpha)
shape = self._data.shape
- gl.glUniform2f(program.uniforms['dataScale'], 1./shape[0], 1./shape[1])
+ gl.glUniform2f(program.uniforms['dataScale'], 1./shape[1], 1./shape[0])
gl.glUniform1i(program.uniforms['data'], self._texture.texUnit)
diff --git a/silx/gui/plot3d/scene/text.py b/silx/gui/plot3d/scene/text.py
index c2983d5..bacc2e6 100644
--- a/silx/gui/plot3d/scene/text.py
+++ b/silx/gui/plot3d/scene/text.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -251,6 +251,7 @@ class Text2D(primitives.Geometry):
minFilter=gl.GL_NEAREST,
magFilter=gl.GL_NEAREST,
wrap=gl.GL_CLAMP_TO_EDGE)
+ self._texture.prepare()
self._dirtyAlign = True # To force update of offset
if self._dirtyAlign:
diff --git a/silx/gui/plot3d/scene/transform.py b/silx/gui/plot3d/scene/transform.py
index 1b82397..43b739b 100644
--- a/silx/gui/plot3d/scene/transform.py
+++ b/silx/gui/plot3d/scene/transform.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -855,13 +855,13 @@ class _Projection(Transform):
class Orthographic(_Projection):
- """Orthographic (i.e., parallel) projection which keeps aspect ratio.
+ """Orthographic (i.e., parallel) projection which can keep aspect ratio.
Clipping planes are adjusted to match the aspect ratio of
- the :attr:`size` attribute.
+ the :attr:`size` attribute if :attr:`keepaspect` is True.
- The left, right, bottom and top parameters defines the area which must
- always remain visible.
+ In this case, the left, right, bottom and top parameters defines the area
+ which must always remain visible.
Effective clipping planes are adjusted to keep the aspect ratio.
:param float left: Coord of the left clipping plane.
@@ -873,12 +873,15 @@ class Orthographic(_Projection):
:param size:
Viewport's size used to compute the aspect ratio (width, height).
:type size: 2-tuple of float
+ :param bool keepaspect:
+ True (default) to keep aspect ratio, False otherwise.
"""
def __init__(self, left=0., right=1., bottom=1., top=0., near=-1., far=1.,
- size=(1., 1.)):
+ size=(1., 1.), keepaspect=True):
self._left, self._right = left, right
self._bottom, self._top = bottom, top
+ self._keepaspect = bool(keepaspect)
super(Orthographic, self).__init__(near, far, checkDepthExtent=False,
size=size)
# _update called when setting size
@@ -888,22 +891,23 @@ class Orthographic(_Projection):
self.left, self.right, self.bottom, self.top, self.near, self.far)
def _update(self, left, right, bottom, top):
- width, height = self.size
- aspect = width / height
+ if self.keepaspect:
+ width, height = self.size
+ aspect = width / height
- orthoaspect = abs(left - right) / abs(bottom - top)
+ orthoaspect = abs(left - right) / abs(bottom - top)
- if orthoaspect >= aspect: # Keep width, enlarge height
- newheight = \
- numpy.sign(top - bottom) * abs(left - right) / aspect
- bottom = 0.5 * (bottom + top) - 0.5 * newheight
- top = bottom + newheight
+ if orthoaspect >= aspect: # Keep width, enlarge height
+ newheight = \
+ numpy.sign(top - bottom) * abs(left - right) / aspect
+ bottom = 0.5 * (bottom + top) - 0.5 * newheight
+ top = bottom + newheight
- else: # Keep height, enlarge width
- newwidth = \
- numpy.sign(right - left) * abs(bottom - top) * aspect
- left = 0.5 * (left + right) - 0.5 * newwidth
- right = left + newwidth
+ else: # Keep height, enlarge width
+ newwidth = \
+ numpy.sign(right - left) * abs(bottom - top) * aspect
+ left = 0.5 * (left + right) - 0.5 * newwidth
+ right = left + newwidth
# Store values
self._left, self._right = left, right
@@ -942,15 +946,30 @@ class Orthographic(_Projection):
@property
def size(self):
- """Viewport size as a 2-tuple of float (width, height) or None."""
+ """Viewport size as a 2-tuple of float (width, height)"""
return self._size
@size.setter
def size(self, size):
assert len(size) == 2
- self._size = float(size[0]), float(size[1])
- self._update(self.left, self.right, self.bottom, self.top)
- self.notify()
+ size = float(size[0]), float(size[1])
+ if size != self._size:
+ self._size = size
+ self._update(self.left, self.right, self.bottom, self.top)
+ self.notify()
+
+ @property
+ def keepaspect(self):
+ """True to keep aspect ratio, False otherwise."""
+ return self._keepaspect
+
+ @keepaspect.setter
+ def keepaspect(self, aspect):
+ aspect = bool(aspect)
+ if aspect != self._keepaspect:
+ self._keepaspect = aspect
+ self._update(self.left, self.right, self.bottom, self.top)
+ self.notify()
class Ortho2DWidget(_Projection):
diff --git a/silx/gui/plot3d/scene/utils.py b/silx/gui/plot3d/scene/utils.py
index bddbcac..c6cd129 100644
--- a/silx/gui/plot3d/scene/utils.py
+++ b/silx/gui/plot3d/scene/utils.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -540,7 +540,7 @@ def segmentVolumeIntersect(segment, nbins):
# bin edges/line intersection points
points = t.reshape(-1, 1) * delta + p0
centers = (points[:-1] + points[1:]) / 2.
- bins = numpy.floor(centers).astype(numpy.int)
+ bins = numpy.floor(centers).astype(numpy.int64)
return bins
diff --git a/silx/gui/plot3d/test/testStatsWidget.py b/silx/gui/plot3d/test/testStatsWidget.py
index 1157aec..bcab1a4 100644
--- a/silx/gui/plot3d/test/testStatsWidget.py
+++ b/silx/gui/plot3d/test/testStatsWidget.py
@@ -34,6 +34,7 @@ import numpy
from silx.utils.testutils import ParametricTestCase
from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot.stats.stats import Stats
from silx.gui import qt
from silx.gui.plot.StatsWidget import BasicStatsWidget
@@ -55,6 +56,7 @@ class TestSceneWidget(TestCaseQt, ParametricTestCase):
# self.qWaitForWindowExposed(self.sceneWidget)
def tearDown(self):
+ Stats._getContext.cache_clear()
self.qapp.processEvents()
self.sceneWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
self.sceneWidget.close()
@@ -147,6 +149,7 @@ class TestScalarFieldView(TestCaseQt):
# self.qWaitForWindowExposed(self.sceneWidget)
def tearDown(self):
+ Stats._getContext.cache_clear()
self.qapp.processEvents()
self.scalarFieldView.setAttribute(qt.Qt.WA_DeleteOnClose)
self.scalarFieldView.close()
diff --git a/silx/gui/test/test_colors.py b/silx/gui/test/test_colors.py
index f83ff58..9e23a93 100755
--- a/silx/gui/test/test_colors.py
+++ b/silx/gui/test/test_colors.py
@@ -113,6 +113,20 @@ class TestApplyColormapToData(ParametricTestCase):
self.assertEqual(len(value), 1)
self.assertEqual(value[0, 0], 128)
+ def testNaNColor(self):
+ """Test Colormap.applyToData with NaN values"""
+ colormap = Colormap(name='gray', normalization='linear')
+ colormap.setNaNColor('red')
+ self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0))
+
+ data = numpy.array([50., numpy.nan])
+ image = items.ImageData()
+ image.setData(numpy.array([[0, 100]]))
+ value = colormap.applyToData(data, reference=image)
+ self.assertEqual(len(value), 2)
+ self.assertTrue(numpy.array_equal(value[0], (128, 128, 128, 255)))
+ self.assertTrue(numpy.array_equal(value[1], (255, 0, 0, 255)))
+
class TestDictAPI(unittest.TestCase):
"""Make sure the old dictionary API is working
@@ -436,9 +450,10 @@ class TestObjectAPI(ParametricTestCase):
Colormap(name="viridis"),
Colormap(normalization=Colormap.SQRT)
]
- gamma = Colormap(normalization=Colormap.GAMMA)
- gamma.setGammaNormalizationParameter(1.2)
- colormaps.append(gamma)
+ cmap = Colormap(normalization=Colormap.GAMMA)
+ cmap.setGammaNormalizationParameter(1.2)
+ cmap.setNaNColor('red')
+ colormaps.append(cmap)
for expected in colormaps:
with self.subTest(colormap=expected):
state = expected.saveState()
@@ -459,6 +474,21 @@ class TestObjectAPI(ParametricTestCase):
expected = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
self.assertEqual(colormap, expected)
+ def testStorageV2(self):
+ state = b'\x00\x00\x00\x10\x00C\x00o\x00l\x00o\x00r\x00m\x00a\x00p\x00'\
+ b'\x00\x00\x02\x00\x00\x00\x0e\x00v\x00i\x00r\x00i\x00d\x00i\x00'\
+ b's\x00\x00\x00\x00\x06\x00?\xf0\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x00\x06\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06'\
+ b'\x00l\x00o\x00g\x00\x00\x00\x0c\x00m\x00i\x00n\x00m\x00a\x00x'
+ state = qt.QByteArray(state)
+ colormap = Colormap()
+ colormap.restoreState(state)
+
+ expected = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
+ expected.setGammaNormalizationParameter(1.5)
+ self.assertEqual(colormap, expected)
+
+
class TestPreferredColormaps(unittest.TestCase):
"""Test get|setPreferredColormaps functions"""
@@ -540,20 +570,25 @@ class TestAutoscaleRange(ParametricTestCase):
def testAutoscaleRange(self):
nan = numpy.nan
+ data_std_inside = numpy.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2])
+ data_std_inside_nan = numpy.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, numpy.nan])
data = [
# Positive values
(Colormap.LINEAR, Colormap.MINMAX, numpy.array([10, 20, 50]), (10, 50)),
(Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100]), (10, 100)),
- (Colormap.LINEAR, Colormap.STDDEV3, numpy.array([10, 100]), (-80, 190)),
- (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100]), (1, 1000)),
+ (Colormap.LINEAR, Colormap.STDDEV3, data_std_inside, (0.026671473215424735, 1.9733285267845753)),
+ (Colormap.LOGARITHM, Colormap.STDDEV3, data_std_inside, (1, 1.6733506885453602)),
+ (Colormap.LINEAR, Colormap.STDDEV3, numpy.array([10, 100]), (10, 100)),
+ (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100]), (10, 100)),
+
# With nan
(Colormap.LINEAR, Colormap.MINMAX, numpy.array([10, 20, 50, nan]), (10, 50)),
(Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100, nan]), (10, 100)),
- (Colormap.LINEAR, Colormap.STDDEV3, numpy.array([10, 100, nan]), (-80, 190)),
- (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100, nan]), (1, 1000)),
+ (Colormap.LINEAR, Colormap.STDDEV3, data_std_inside_nan, (0.026671473215424735, 1.9733285267845753)),
+ (Colormap.LOGARITHM, Colormap.STDDEV3, data_std_inside_nan, (1, 1.6733506885453602)),
# With negative
(Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100, -50]), (10, 100)),
- (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100, -10]), (1, 1000)),
+ (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100, -10]), (10, 100)),
]
for norm, mode, array, expectedRange in data:
with self.subTest(norm=norm, mode=mode, array=array):
diff --git a/silx/gui/utils/glutils.py b/silx/gui/utils/glutils.py
index fca9a32..83cfd89 100644
--- a/silx/gui/utils/glutils.py
+++ b/silx/gui/utils/glutils.py
@@ -27,6 +27,13 @@
import os
import sys
+
+if __name__ == "__main__":
+ # When run as a script, remove directory from sys.path
+ # This avoids other script in same directory to override Python modules
+ if os.path.abspath(sys.path[0]) == os.path.abspath(os.path.dirname(__file__)):
+ sys.path.pop(0)
+
import subprocess
from silx.gui import qt
diff --git a/silx/gui/utils/matplotlib.py b/silx/gui/utils/matplotlib.py
new file mode 100644
index 0000000..484e01a
--- /dev/null
+++ b/silx/gui/utils/matplotlib.py
@@ -0,0 +1,71 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import
+
+"""This module initializes matplotlib and sets-up the backend to use.
+
+It MUST be imported prior to any other import of matplotlib.
+
+It provides the matplotlib :class:`FigureCanvasQTAgg` class corresponding
+to the used backend.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/05/2018"
+
+
+from pkg_resources import parse_version
+import matplotlib
+
+from .. import qt
+
+
+def _matplotlib_use(backend, force):
+ """Wrapper of `matplotlib.use` to set-up backend.
+
+ It adds extra initialization for PySide and PySide2 with matplotlib < 2.2.
+ """
+ # This is kept for compatibility with matplotlib < 2.2
+ if parse_version(matplotlib.__version__) < parse_version('2.2'):
+ if qt.BINDING == 'PySide':
+ matplotlib.rcParams['backend.qt4'] = 'PySide'
+ if qt.BINDING == 'PySide2':
+ matplotlib.rcParams['backend.qt5'] = 'PySide2'
+
+ matplotlib.use(backend, force=force)
+
+
+if qt.BINDING in ('PyQt4', 'PySide'):
+ _matplotlib_use('Qt4Agg', force=False)
+ from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg # noqa
+
+elif qt.BINDING in ('PyQt5', 'PySide2'):
+ _matplotlib_use('Qt5Agg', force=False)
+ from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg # noqa
+
+else:
+ raise ImportError("Unsupported Qt binding: %s" % qt.BINDING)
diff --git a/silx/gui/utils/signal.py b/silx/gui/utils/signal.py
new file mode 100644
index 0000000..359f5cc
--- /dev/null
+++ b/silx/gui/utils/signal.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2012 University of North Carolina at Chapel Hill, Luke Campagnola
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module contains utils relative to qt Signal
+"""
+
+from silx.gui import qt
+import weakref
+from time import time
+from silx.gui.utils import concurrent
+
+__all__ = ['SignalProxy']
+__authors__ = ['L. Campagnola', 'M. Liberty']
+__license__ = "MIT"
+
+
+class SignalProxy(qt.QObject):
+ """
+ This peace of code come from pyqtgraph
+ Object which collects rapid-fire signals and condenses them
+ into a single signal or a rate-limited stream of signals.
+ Used, for example, to prevent a SpinBox from generating multiple
+ signals when the mouse wheel is rolled over it.
+
+ Emits sigDelayed after input signals have stopped for a certain period of time.
+ """
+
+ sigDelayed = qt.Signal(object)
+
+ def __init__(self, signal, delay=0.3, rateLimit=0, slot=None):
+ """Initialization arguments:
+ signal - a bound Signal or pyqtSignal instance
+ delay - Time (in seconds) to wait for signals to stop before emitting (default 0.3s)
+ slot - Optional function to connect sigDelayed to.
+ rateLimit - (signals/sec) if greater than 0, this allows signals to stream out at a
+ steady rate while they are being received.
+ """
+
+ qt.QObject.__init__(self)
+ signal.connect(self.signalReceived)
+ self.signal = signal
+ self.delay = delay
+ self.rateLimit = rateLimit
+ self.args = None
+ self.timer = qt.QTimer()
+ self.timer.timeout.connect(self.flush)
+ self.blockSignal = False
+ self.slot = weakref.ref(slot)
+ self.lastFlushTime = None
+ if slot is not None:
+ self.sigDelayed.connect(slot)
+
+ def setDelay(self, delay):
+ self.delay = delay
+
+ def signalReceived(self, *args):
+ """Received signal. Cancel previous timer and store args to be forwarded later."""
+ if self.blockSignal:
+ return
+ self.args = args
+ if self.rateLimit == 0:
+ concurrent.submitToQtMainThread(self.timer.stop)
+ concurrent.submitToQtMainThread(self.timer.start, (self.delay * 1000) + 1)
+ else:
+ now = time()
+ if self.lastFlushTime is None:
+ leakTime = 0
+ else:
+ lastFlush = self.lastFlushTime
+ leakTime = max(0, (lastFlush + (1.0 / self.rateLimit)) - now)
+
+ concurrent.submitToQtMainThread(self.timer.stop)
+ concurrent.submitToQtMainThread(self.timer.start, (min(leakTime, self.delay) * 1000) + 1)
+ # self.timer.stop()
+ # self.timer.start((min(leakTime, self.delay) * 1000) + 1)
+
+ def flush(self):
+ """If there is a signal queued up, send it now."""
+ if self.args is None or self.blockSignal:
+ return False
+ args, self.args = self.args, None
+ concurrent.submitToQtMainThread(self.timer.stop)
+ self.lastFlushTime = time()
+ # self.emit(self.signal, *self.args)
+ concurrent.submitToQtMainThread(self.sigDelayed.emit, args)
+ # self.sigDelayed.emit(args)
+ return True
+
+ def disconnect(self):
+ self.blockSignal = True
+ try:
+ self.signal.disconnect(self.signalReceived)
+ except:
+ pass
+ try:
+ self.sigDelayed.disconnect(self.slot)
+ except:
+ pass
+
+
+if __name__ == '__main__':
+ app = qt.QApplication([])
+ win = qt.QMainWindow()
+ spin = qt.QSpinBox()
+ win.setCentralWidget(spin)
+ win.show()
+
+
+ def fn(*args):
+ print("Raw signal:", args)
+
+
+ def fn2(*args):
+ print("Delayed signal:", args)
+
+
+ spin.valueChanged.connect(fn)
+ # proxy = proxyConnect(spin, QtCore.SIGNAL('valueChanged(int)'), fn)
+ proxy = SignalProxy(spin.valueChanged, delay=0.5, slot=fn2)
diff --git a/silx/gui/utils/testutils.py b/silx/gui/utils/testutils.py
index c086657..30b9e34 100644
--- a/silx/gui/utils/testutils.py
+++ b/silx/gui/utils/testutils.py
@@ -142,8 +142,6 @@ class TestCaseQt(unittest.TestCase):
@classmethod
def tearDownClass(cls):
sys.excepthook = cls._oldExceptionHook
- if cls._qapp is not None:
- cls._qapp = None
def setUp(self):
"""Get the list of existing widgets."""
diff --git a/silx/gui/widgets/ElidedLabel.py b/silx/gui/widgets/ElidedLabel.py
index 58513c7..fe53bb9 100644
--- a/silx/gui/widgets/ElidedLabel.py
+++ b/silx/gui/widgets/ElidedLabel.py
@@ -61,12 +61,12 @@ class ElidedLabel(qt.QLabel):
self.__updateText()
def __updateMinimumSize(self):
- metrics = qt.QFontMetrics(self.font())
+ metrics = self.fontMetrics()
width = metrics.width("...")
self.setMinimumWidth(width)
def __updateText(self):
- metrics = qt.QFontMetrics(self.font())
+ metrics = self.fontMetrics()
elidedText = metrics.elidedText(self.__text, self.__elideMode, self.width())
qt.QLabel.setText(self, elidedText)
wasElided = self.__textIsElided
diff --git a/silx/gui/widgets/test/__init__.py b/silx/gui/widgets/test/__init__.py
index b868171..9aaec76 100644
--- a/silx/gui/widgets/test/__init__.py
+++ b/silx/gui/widgets/test/__init__.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -34,6 +34,7 @@ from . import test_boxlayoutdockwidget
from . import test_rangeslider
from . import test_flowlayout
from . import test_elidedlabel
+from . import test_legendiconwidget
__authors__ = ["V. Valls", "P. Knobel"]
__license__ = "MIT"
@@ -53,5 +54,6 @@ def suite():
test_rangeslider.suite(),
test_flowlayout.suite(),
test_elidedlabel.suite(),
+ test_legendiconwidget.suite(),
])
return test_suite
diff --git a/silx/gui/widgets/test/test_legendiconwidget.py b/silx/gui/widgets/test/test_legendiconwidget.py
new file mode 100644
index 0000000..f845f75
--- /dev/null
+++ b/silx/gui/widgets/test/test_legendiconwidget.py
@@ -0,0 +1,74 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for LegendIconWidget"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/10/2020"
+
+import unittest
+
+from silx.gui import qt
+from silx.gui.widgets.LegendIconWidget import LegendIconWidget
+from silx.gui.utils.testutils import TestCaseQt
+from silx.utils.testutils import ParametricTestCase
+
+
+class TestLegendIconWidget(TestCaseQt, ParametricTestCase):
+ """Tests for TestRangeSlider"""
+
+ def setUp(self):
+ self.widget = LegendIconWidget()
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+ self.qapp.processEvents()
+
+ def testCreate(self):
+ self.qapp.processEvents()
+
+ def testColormap(self):
+ self.widget.setColormap("viridis")
+ self.qapp.processEvents()
+
+ def testSymbol(self):
+ self.widget.setSymbol("o")
+ self.widget.setSymbolColormap("viridis")
+ self.qapp.processEvents()
+
+
+def suite():
+ loader = unittest.defaultTestLoader.loadTestsFromTestCase
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(loader(TestLegendIconWidget))
+ return test_suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/silx/image/marchingsquares/_mergeimpl.pyx b/silx/image/marchingsquares/_mergeimpl.pyx
index 7286a66..5a7a3b5 100644
--- a/silx/image/marchingsquares/_mergeimpl.pyx
+++ b/silx/image/marchingsquares/_mergeimpl.pyx
@@ -1,6 +1,6 @@
# coding: utf-8
# /*##########################################################################
-# Copyright (C) 2018 European Synchrotron Radiation Facility
+# Copyright (C) 2018-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -48,7 +48,7 @@ cimport libc.string
cimport cython
-include "../../utils/_have_openmp.pxi"
+from ...utils._have_openmp cimport COMPILED_WITH_OPENMP
"""Store in the module if it was compiled with OpenMP"""
cdef double EPSILON = numpy.finfo(numpy.float64).eps
diff --git a/silx/image/tomography.py b/silx/image/tomography.py
index c2aedd8..53855c1 100644
--- a/silx/image/tomography.py
+++ b/silx/image/tomography.py
@@ -32,6 +32,7 @@ __date__ = "12/09/2017"
import numpy as np
from math import pi
+from functools import lru_cache
from itertools import product
from bisect import bisect
from silx.math.fit import leastsq
@@ -128,6 +129,7 @@ def compute_fourier_filter(dwidth_padded, filter_name, cutoff=1.):
return filt_f
+@lru_cache(maxsize=1)
def generate_powers():
"""
Generate a list of powers of [2, 3, 5, 7],
diff --git a/silx/io/commonh5.py b/silx/io/commonh5.py
index b624816..57232d8 100644
--- a/silx/io/commonh5.py
+++ b/silx/io/commonh5.py
@@ -1,6 +1,6 @@
# coding: utf-8
# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -376,6 +376,24 @@ class Dataset(Node):
There is no chunks."""
return None
+ @property
+ def is_virtual(self):
+ """Checks virtual data as provided by `h5py.Dataset`"""
+ return False
+
+ def virtual_sources(self):
+ """Returns virtual dataset sources as provided by `h5py.Dataset`.
+
+ :rtype: list"""
+ raise RuntimeError("Not a virtual dataset")
+
+ @property
+ def external(self):
+ """Returns external sources as provided by `h5py.Dataset`.
+
+ :rtype: list or None"""
+ return None
+
def __array__(self, dtype=None):
# Special case for (0,)*-shape datasets
if numpy.product(self.shape) == 0:
@@ -958,7 +976,7 @@ class Group(Node):
raise TypeError("Path are not supported")
if data is None:
if dtype is None:
- dtype = numpy.float
+ dtype = numpy.float64
data = numpy.empty(shape=shape, dtype=dtype)
elif dtype is not None:
data = data.astype(dtype)
diff --git a/silx/io/dictdump.py b/silx/io/dictdump.py
index f2318e0..bbb244a 100644
--- a/silx/io/dictdump.py
+++ b/silx/io/dictdump.py
@@ -34,9 +34,11 @@ import sys
import h5py
from .configdict import ConfigDict
-from .utils import is_group
+from .utils import is_group, is_link, is_softlink, is_externallink
from .utils import is_file as is_h5_file_like
from .utils import open as h5open
+from .utils import h5py_read_dataset
+from .utils import H5pyAttributesReadWrapper
__authors__ = ["P. Knobel"]
__license__ = "MIT"
@@ -44,35 +46,24 @@ __date__ = "17/07/2018"
logger = logging.getLogger(__name__)
-string_types = (basestring,) if sys.version_info[0] == 2 else (str,) # noqa
+vlen_utf8 = h5py.special_dtype(vlen=str)
+vlen_bytes = h5py.special_dtype(vlen=bytes)
-def _prepare_hdf5_dataset(array_like):
+def _prepare_hdf5_write_value(array_like):
"""Cast a python object into a numpy array in a HDF5 friendly format.
:param array_like: Input dataset in a type that can be digested by
``numpy.array()`` (`str`, `list`, `numpy.ndarray`…)
:return: ``numpy.ndarray`` ready to be written as an HDF5 dataset
"""
- # simple strings
- if isinstance(array_like, string_types):
- array_like = numpy.string_(array_like)
-
- # Ensure our data is a numpy.ndarray
- if not isinstance(array_like, (numpy.ndarray, numpy.string_)):
- array = numpy.array(array_like)
+ array = numpy.asarray(array_like)
+ if numpy.issubdtype(array.dtype, numpy.bytes_):
+ return numpy.array(array_like, dtype=vlen_bytes)
+ elif numpy.issubdtype(array.dtype, numpy.str_):
+ return numpy.array(array_like, dtype=vlen_utf8)
else:
- array = array_like
-
- # handle list of strings or numpy array of strings
- if not isinstance(array, numpy.string_):
- data_kind = array.dtype.kind
- # unicode: convert to byte strings
- # (http://docs.h5py.org/en/latest/strings.html)
- if data_kind.lower() in ["s", "u"]:
- array = numpy.asarray(array, dtype=numpy.string_)
-
- return array
+ return array
class _SafeH5FileWrite(object):
@@ -219,150 +210,145 @@ def dicttoh5(treedict, h5file, h5path='/',
h5f.create_group(h5path)
for key in filter(lambda k: not isinstance(k, tuple), treedict):
- if isinstance(treedict[key], dict) and len(treedict[key]):
+ key_is_group = isinstance(treedict[key], dict)
+ h5name = h5path + key
+
+ if key_is_group and treedict[key]:
# non-empty group: recurse
- dicttoh5(treedict[key], h5f, h5path + key,
+ dicttoh5(treedict[key], h5f, h5name,
overwrite_data=overwrite_data,
create_dataset_args=create_dataset_args)
+ continue
- elif treedict[key] is None or (isinstance(treedict[key], dict) and
- not len(treedict[key])):
- if (h5path + key) in h5f:
- if overwrite_data is True:
- del h5f[h5path + key]
- else:
- logger.warning('key (%s) already exists. '
- 'Not overwriting.' % (h5path + key))
- continue
- # Create empty group
- h5f.create_group(h5path + key)
+ if h5name in h5f:
+ # key already exists: delete or skip
+ if overwrite_data is True:
+ del h5f[h5name]
+ else:
+ logger.warning('key (%s) already exists. '
+ 'Not overwriting.' % (h5name))
+ continue
+
+ value = treedict[key]
+ if value is None or key_is_group:
+ # Create empty group
+ h5f.create_group(h5name)
+ elif is_link(value):
+ h5f[h5name] = value
else:
- ds = _prepare_hdf5_dataset(treedict[key])
+ data = _prepare_hdf5_write_value(value)
# can't apply filters on scalars (datasets with shape == () )
- if ds.shape == () or create_dataset_args is None:
- if h5path + key in h5f:
- if overwrite_data is True:
- del h5f[h5path + key]
- else:
- logger.warning('key (%s) already exists. '
- 'Not overwriting.' % (h5path + key))
- continue
-
- h5f.create_dataset(h5path + key,
- data=ds)
+ if data.shape == () or create_dataset_args is None:
+ h5f.create_dataset(h5name,
+ data=data)
else:
- if h5path + key in h5f:
- if overwrite_data is True:
- del h5f[h5path + key]
- else:
- logger.warning('key (%s) already exists. '
- 'Not overwriting.' % (h5path + key))
- continue
-
- h5f.create_dataset(h5path + key,
- data=ds,
+ h5f.create_dataset(h5name,
+ data=data,
**create_dataset_args)
# deal with h5 attributes which have tuples as keys in treedict
for key in filter(lambda k: isinstance(k, tuple), treedict):
- if (h5path + key[0]) not in h5f:
+ assert len(key) == 2, "attribute must be defined by 2 values"
+ h5name = h5path + key[0]
+ attr_name = key[1]
+
+ if h5name not in h5f:
# Create empty group if key for attr does not exist
- h5f.create_group(h5path + key[0])
+ h5f.create_group(h5name)
logger.warning(
"key (%s) does not exist. attr %s "
- "will be written to ." % (h5path + key[0], key[1])
+ "will be written to ." % (h5name, attr_name)
)
- if key[1] in h5f[h5path + key[0]].attrs:
+ if attr_name in h5f[h5name].attrs:
if not overwrite_data:
logger.warning(
"attribute %s@%s already exists. Not overwriting."
- "" % (h5path + key[0], key[1])
+ "" % (h5name, attr_name)
)
continue
# Write attribute
value = treedict[key]
+ data = _prepare_hdf5_write_value(value)
+ h5f[h5name].attrs[attr_name] = data
- # Makes list/tuple of str being encoded as vlen unicode array
- # Workaround for h5py<2.9.0 (e.g. debian 10).
- if (isinstance(value, (list, tuple)) and
- numpy.asarray(value).dtype.type == numpy.unicode_):
- value = numpy.array(value, dtype=h5py.special_dtype(vlen=str))
-
- h5f[h5path + key[0]].attrs[key[1]] = value
-
-def dicttonx(
- treedict,
- h5file,
- h5path="/",
- mode="w",
- overwrite_data=False,
- create_dataset_args=None,
-):
- """
- Write a nested dictionary to a HDF5 file, using string keys as member names.
- The NeXus convention is used to identify attributes with ``"@"`` character,
- therefor the dataset_names should not contain ``"@"``.
+def nexus_to_h5_dict(treedict, parents=tuple()):
+ """The following conversions are applied:
+ * key with "{name}@{attr_name}" notation: key converted to 2-tuple
+ * key with ">{url}" notation: strip ">" and convert value to
+ h5py.SoftLink or h5py.ExternalLink
:param treedict: Nested dictionary/tree structure with strings as keys
and array-like objects as leafs. The ``"/"`` character can be used
to define sub tree. The ``"@"`` character is used to write attributes.
+ The ``">"`` prefix is used to define links.
+ :param parents: Needed to resolve up-links (tuple of HDF5 group names)
- Detais on all other params can be found in doc of dicttoh5.
+ :rtype dict:
+ """
+ copy = dict()
+ for key, value in treedict.items():
+ if "@" in key:
+ key = tuple(key.rsplit("@", 1))
+ elif key.startswith(">"):
+ if isinstance(value, str):
+ key = key[1:]
+ first, sep, second = value.partition("::")
+ if sep:
+ value = h5py.ExternalLink(first, second)
+ else:
+ if ".." in first:
+ # Up-links not supported: make absolute
+ parts = []
+ for p in list(parents) + first.split("/"):
+ if not p or p == ".":
+ continue
+ elif p == "..":
+ parts.pop(-1)
+ else:
+ parts.append(p)
+ first = "/" + "/".join(parts)
+ value = h5py.SoftLink(first)
+ elif is_link(value):
+ key = key[1:]
+ if isinstance(value, dict):
+ copy[key] = nexus_to_h5_dict(value, parents=parents+(key,))
+ else:
+ copy[key] = value
+ return copy
- Example::
- import numpy
- from silx.io.dictdump import dicttonx
+def h5_to_nexus_dict(treedict):
+ """The following conversions are applied:
+ * 2-tuple key: converted to string ("@" notation)
+ * h5py.Softlink value: converted to string (">" key prefix)
+ * h5py.ExternalLink value: converted to string (">" key prefix)
- gauss = {
- "entry":{
- "title":u"A plot of a gaussian",
- "plot": {
- "y": numpy.array([0.08, 0.19, 0.39, 0.66, 0.9, 1.,
- 0.9, 0.66, 0.39, 0.19, 0.08]),
- "x": numpy.arange(0,1.1,.1),
- "@signal": "y",
- "@axes": "x",
- "@NX_class":u"NXdata",
- "title:u"Gauss Plot",
- },
- "@NX_class":u"NXentry",
- "default":"plot",
- }
- "@NX_class": u"NXroot",
- "@default": "entry",
- }
+ :param treedict: Nested dictionary/tree structure with strings as keys
+ and array-like objects as leafs. The ``"/"`` character can be used
+ to define sub tree.
- dicttonx(gauss,"test.h5")
+ :rtype dict:
"""
-
- def copy_keys_keep_values(original):
- # create a new treedict with with modified keys but keep values
- copy = dict()
- for key, value in original.items():
- if "@" in key:
- newkey = tuple(key.rsplit("@", 1))
- else:
- newkey = key
- if isinstance(value, dict):
- copy[newkey] = copy_keys_keep_values(value)
- else:
- copy[newkey] = value
- return copy
-
- nxtreedict = copy_keys_keep_values(treedict)
- dicttoh5(
- nxtreedict,
- h5file,
- h5path=h5path,
- mode=mode,
- overwrite_data=overwrite_data,
- create_dataset_args=create_dataset_args,
- )
+ copy = dict()
+ for key, value in treedict.items():
+ if isinstance(key, tuple):
+ assert len(key)==2, "attribute must be defined by 2 values"
+ key = "%s@%s" % (key[0], key[1])
+ elif is_softlink(value):
+ key = ">" + key
+ value = value.path
+ elif is_externallink(value):
+ key = ">" + key
+ value = value.filename + "::" + value.path
+ if isinstance(value, dict):
+ copy[key] = h5_to_nexus_dict(value)
+ else:
+ copy[key] = value
+ return copy
def _name_contains_string_in_list(name, strlist):
@@ -374,7 +360,31 @@ def _name_contains_string_in_list(name, strlist):
return False
-def h5todict(h5file, path="/", exclude_names=None, asarray=True):
+def _handle_error(mode: str, exception, msg: str, *args) -> None:
+ """Handle errors.
+
+ :param str mode: 'raise', 'log', 'ignore'
+ :param type exception: Exception class to use in 'raise' mode
+ :param str msg: Error message template
+ :param List[str] args: Arguments for error message template
+ """
+ if mode == 'ignore':
+ return # no-op
+ elif mode == 'log':
+ logger.error(msg, *args)
+ elif mode == 'raise':
+ raise exception(msg % args)
+ else:
+ raise ValueError("Unsupported error handling: %s" % mode)
+
+
+def h5todict(h5file,
+ path="/",
+ exclude_names=None,
+ asarray=True,
+ dereference_links=True,
+ include_attributes=False,
+ errors='raise'):
"""Read a HDF5 file and return a nested dictionary with the complete file
structure and all data.
@@ -397,7 +407,7 @@ def h5todict(h5file, path="/", exclude_names=None, asarray=True):
.. note:: This function requires `h5py <http://www.h5py.org/>`_ to be
installed.
- .. note:: If you write a dictionary to a HDF5 file with
+ .. note:: If you write a dictionary to a HDF5 file with
:func:`dicttoh5` and then read it back with :func:`h5todict`, data
types are not preserved. All values are cast to numpy arrays before
being written to file, and they are read back as numpy arrays (or
@@ -412,28 +422,159 @@ def h5todict(h5file, path="/", exclude_names=None, asarray=True):
a string in this list will be ignored. Default is None (ignore nothing)
:param bool asarray: True (default) to read scalar as arrays, False to
read them as scalar
+ :param bool dereference_links: True (default) to dereference links, False
+ to preserve the link itself
+ :param bool include_attributes: False (default)
+ :param str errors: Handling of errors (HDF5 access issue, broken link,...):
+ - 'raise' (default): Raise an exception
+ - 'log': Log as errors
+ - 'ignore': Ignore errors
:return: Nested dictionary
"""
with _SafeH5FileRead(h5file) as h5f:
ddict = {}
- for key in h5f[path]:
+ if path not in h5f:
+ _handle_error(
+ errors, KeyError, 'Path "%s" does not exist in file.', path)
+ return ddict
+
+ try:
+ root = h5f[path]
+ except KeyError as e:
+ if not isinstance(h5f.get(path, getlink=True), h5py.HardLink):
+ _handle_error(errors,
+ KeyError,
+ 'Cannot retrieve path "%s" (broken link)',
+ path)
+ else:
+ _handle_error(errors, KeyError, ', '.join(e.args))
+ return ddict
+
+ # Read the attributes of the group
+ if include_attributes:
+ attrs = H5pyAttributesReadWrapper(root.attrs)
+ for aname, avalue in attrs.items():
+ ddict[("", aname)] = avalue
+ # Read the children of the group
+ for key in root:
if _name_contains_string_in_list(key, exclude_names):
continue
- if is_group(h5f[path + "/" + key]):
+ h5name = path + "/" + key
+ # Preserve HDF5 link when requested
+ if not dereference_links:
+ lnk = h5f.get(h5name, getlink=True)
+ if is_link(lnk):
+ ddict[key] = lnk
+ continue
+
+ try:
+ h5obj = h5f[h5name]
+ except KeyError as e:
+ if not isinstance(h5f.get(h5name, getlink=True), h5py.HardLink):
+ _handle_error(errors,
+ KeyError,
+ 'Cannot retrieve path "%s" (broken link)',
+ h5name)
+ else:
+ _handle_error(errors, KeyError, ', '.join(e.args))
+ continue
+
+ if is_group(h5obj):
+ # Child is an HDF5 group
ddict[key] = h5todict(h5f,
- path + "/" + key,
+ h5name,
exclude_names=exclude_names,
- asarray=asarray)
+ asarray=asarray,
+ dereference_links=dereference_links,
+ include_attributes=include_attributes)
else:
- # Read HDF5 datset
- data = h5f[path + "/" + key][()]
- if asarray: # Convert HDF5 dataset to numpy array
- data = numpy.array(data, copy=False)
- ddict[key] = data
-
+ # Child is an HDF5 dataset
+ try:
+ data = h5py_read_dataset(h5obj)
+ except OSError:
+ _handle_error(errors,
+ OSError,
+ 'Cannot retrieve dataset "%s"',
+ h5name)
+ else:
+ if asarray: # Convert HDF5 dataset to numpy array
+ data = numpy.array(data, copy=False)
+ ddict[key] = data
+ # Read the attributes of the child
+ if include_attributes:
+ attrs = H5pyAttributesReadWrapper(h5obj.attrs)
+ for aname, avalue in attrs.items():
+ ddict[(key, aname)] = avalue
return ddict
+def dicttonx(treedict, h5file, h5path="/", **kw):
+ """
+ Write a nested dictionary to a HDF5 file, using string keys as member names.
+ The NeXus convention is used to identify attributes with ``"@"`` character,
+ therefore the dataset_names should not contain ``"@"``.
+
+ Similarly, links are identified by keys starting with the ``">"`` character.
+ The corresponding value can be a soft or external link.
+
+ :param treedict: Nested dictionary/tree structure with strings as keys
+ and array-like objects as leafs. The ``"/"`` character can be used
+ to define sub tree. The ``"@"`` character is used to write attributes.
+ The ``">"`` prefix is used to define links.
+
+ The named parameters are passed to dicttoh5.
+
+ Example::
+
+ import numpy
+ from silx.io.dictdump import dicttonx
+
+ gauss = {
+ "entry":{
+ "title":u"A plot of a gaussian",
+ "instrument": {
+ "@NX_class": u"NXinstrument",
+ "positioners": {
+ "@NX_class": u"NXCollection",
+ "x": numpy.arange(0,1.1,.1)
+ }
+ }
+ "plot": {
+ "y": numpy.array([0.08, 0.19, 0.39, 0.66, 0.9, 1.,
+ 0.9, 0.66, 0.39, 0.19, 0.08]),
+ ">x": "../instrument/positioners/x",
+ "@signal": "y",
+ "@axes": "x",
+ "@NX_class":u"NXdata",
+ "title:u"Gauss Plot",
+ },
+ "@NX_class": u"NXentry",
+ "default":"plot",
+ }
+ "@NX_class": u"NXroot",
+ "@default": "entry",
+ }
+
+ dicttonx(gauss,"test.h5")
+ """
+ parents = tuple(p for p in h5path.split("/") if p)
+ nxtreedict = nexus_to_h5_dict(treedict, parents=parents)
+ dicttoh5(nxtreedict, h5file, h5path=h5path, **kw)
+
+
+def nxtodict(h5file, **kw):
+ """Read a HDF5 file and return a nested dictionary with the complete file
+ structure and all data.
+
+ As opposed to h5todict, all keys will be strings and no h5py objects are
+ present in the tree.
+
+ The named parameters are passed to h5todict.
+ """
+ nxtreedict = h5todict(h5file, **kw)
+ return h5_to_nexus_dict(nxtreedict)
+
+
def dicttojson(ddict, jsonfile, indent=None, mode="w"):
"""Serialize ``ddict`` as a JSON formatted stream to ``jsonfile``.
diff --git a/silx/io/fabioh5.py b/silx/io/fabioh5.py
index cfaa0a0..2fd719d 100755
--- a/silx/io/fabioh5.py
+++ b/silx/io/fabioh5.py
@@ -1,6 +1,6 @@
# coding: utf-8
# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -656,13 +656,13 @@ class FabioReader(object):
elif result_type.kind == "U":
none_value = u""
elif result_type.kind == "f":
- none_value = numpy.float("NaN")
+ none_value = numpy.float64("NaN")
elif result_type.kind == "i":
- none_value = numpy.int(0)
+ none_value = numpy.int64(0)
elif result_type.kind == "u":
- none_value = numpy.int(0)
+ none_value = numpy.int64(0)
elif result_type.kind == "b":
- none_value = numpy.bool(False)
+ none_value = numpy.bool_(False)
else:
none_value = None
diff --git a/silx/io/nxdata/parse.py b/silx/io/nxdata/parse.py
index 6bd18d6..b1c1bba 100644
--- a/silx/io/nxdata/parse.py
+++ b/silx/io/nxdata/parse.py
@@ -45,7 +45,7 @@ import json
import numpy
import six
-from silx.io.utils import is_group, is_file, is_dataset
+from silx.io.utils import is_group, is_file, is_dataset, h5py_read_dataset
from ._utils import get_attr_as_unicode, INTERPDIM, nxdata_logger, \
get_uncertainties_names, get_signal_name, \
@@ -628,7 +628,7 @@ class NXdata(object):
data_dataset_names = [self.signal_name] + self.axes_dataset_names
if (title is not None and is_dataset(title) and
"title" not in data_dataset_names):
- return str(title[()])
+ return str(h5py_read_dataset(title))
title = self.group.attrs.get("title")
if title is None:
diff --git a/silx/io/setup.py b/silx/io/setup.py
index 4aaf324..9cafa17 100644
--- a/silx/io/setup.py
+++ b/silx/io/setup.py
@@ -51,7 +51,7 @@ else:
SPECFILE_USE_GNU_SOURCE = int(SPECFILE_USE_GNU_SOURCE)
if sys.platform == "win32":
- define_macros = [('WIN32', None)]
+ define_macros = [('WIN32', None), ('SPECFILE_POSIX', None)]
elif os.name.lower().startswith('posix'):
define_macros = [('SPECFILE_POSIX', None)]
# the best choice is to have _GNU_SOURCE defined
diff --git a/silx/io/specfile/src/locale_management.c b/silx/io/specfile/src/locale_management.c
index 54695f5..0c5f7ca 100644
--- a/silx/io/specfile/src/locale_management.c
+++ b/silx/io/specfile/src/locale_management.c
@@ -39,6 +39,9 @@
# else
# ifdef SPECFILE_POSIX
# include <locale.h>
+# ifndef LOCALE_NAME_MAX_LENGTH
+# define LOCALE_NAME_MAX_LENGTH 85
+# endif
# endif
# endif
#endif
@@ -60,7 +63,7 @@ double PyMcaAtof(const char * inputString)
#else
#ifdef SPECFILE_POSIX
char *currentLocaleBuffer;
- char localeBuffer[21];
+ char localeBuffer[LOCALE_NAME_MAX_LENGTH + 1] = {'\0'};
double result;
currentLocaleBuffer = setlocale(LC_NUMERIC, NULL);
strcpy(localeBuffer, currentLocaleBuffer);
diff --git a/silx/io/test/test_dictdump.py b/silx/io/test/test_dictdump.py
index c0b6914..b99116b 100644
--- a/silx/io/test/test_dictdump.py
+++ b/silx/io/test/test_dictdump.py
@@ -43,6 +43,8 @@ from .. import dictdump
from ..dictdump import dicttoh5, dicttojson, dump
from ..dictdump import h5todict, load
from ..dictdump import logger as dictdump_logger
+from ..utils import is_link
+from ..utils import h5py_read_dataset
def tree():
@@ -58,15 +60,29 @@ city_attrs["Europe"]["France"]["Grenoble"]["inhabitants"] = inhabitants
city_attrs["Europe"]["France"]["Grenoble"]["coordinates"] = [45.1830, 5.7196]
city_attrs["Europe"]["France"]["Tourcoing"]["area"]
+ext_attrs = tree()
+ext_attrs["ext_group"]["dataset"] = 10
+ext_filename = "ext.h5"
+
+link_attrs = tree()
+link_attrs["links"]["group"]["dataset"] = 10
+link_attrs["links"]["group"]["relative_softlink"] = h5py.SoftLink("dataset")
+link_attrs["links"]["relative_softlink"] = h5py.SoftLink("group/dataset")
+link_attrs["links"]["absolute_softlink"] = h5py.SoftLink("/links/group/dataset")
+link_attrs["links"]["external_link"] = h5py.ExternalLink(ext_filename, "/ext_group/dataset")
+
class TestDictToH5(unittest.TestCase):
def setUp(self):
self.tempdir = tempfile.mkdtemp()
self.h5_fname = os.path.join(self.tempdir, "cityattrs.h5")
+ self.h5_ext_fname = os.path.join(self.tempdir, ext_filename)
def tearDown(self):
if os.path.exists(self.h5_fname):
os.unlink(self.h5_fname)
+ if os.path.exists(self.h5_ext_fname):
+ os.unlink(self.h5_ext_fname)
os.rmdir(self.tempdir)
def testH5CityAttrs(self):
@@ -201,31 +217,129 @@ class TestDictToH5(unittest.TestCase):
self.assertEqual(h5file["group/group/dataset"].attrs['attr'], 11)
self.assertEqual(h5file["group/group"].attrs['attr'], 12)
+ def testLinks(self):
+ with h5py.File(self.h5_ext_fname, "w") as h5file:
+ dictdump.dicttoh5(ext_attrs, h5file)
+ with h5py.File(self.h5_fname, "w") as h5file:
+ dictdump.dicttoh5(link_attrs, h5file)
+ with h5py.File(self.h5_fname, "r") as h5file:
+ self.assertEqual(h5file["links/group/dataset"][()], 10)
+ self.assertEqual(h5file["links/group/relative_softlink"][()], 10)
+ self.assertEqual(h5file["links/relative_softlink"][()], 10)
+ self.assertEqual(h5file["links/absolute_softlink"][()], 10)
+ self.assertEqual(h5file["links/external_link"][()], 10)
+
+ def testDumpNumpyArray(self):
+ ddict = {
+ 'darks': {
+ '0': numpy.array([[0, 0, 0], [0, 0, 0]], dtype=numpy.uint16)
+ }
+ }
+ with h5py.File(self.h5_fname, "w") as h5file:
+ dictdump.dicttoh5(ddict, h5file)
+ with h5py.File(self.h5_fname, "r") as h5file:
+ numpy.testing.assert_array_equal(h5py_read_dataset(h5file["darks"]["0"]),
+ ddict['darks']['0'])
+
+
+class TestH5ToDict(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.h5_fname = os.path.join(self.tempdir, "cityattrs.h5")
+ self.h5_ext_fname = os.path.join(self.tempdir, ext_filename)
+ dicttoh5(city_attrs, self.h5_fname)
+ dicttoh5(link_attrs, self.h5_fname, mode="a")
+ dicttoh5(ext_attrs, self.h5_ext_fname)
+
+ def tearDown(self):
+ if os.path.exists(self.h5_fname):
+ os.unlink(self.h5_fname)
+ if os.path.exists(self.h5_ext_fname):
+ os.unlink(self.h5_ext_fname)
+ os.rmdir(self.tempdir)
+
+ def testExcludeNames(self):
+ ddict = h5todict(self.h5_fname, path="/Europe/France",
+ exclude_names=["ourcoing", "inhab", "toto"])
+ self.assertNotIn("Tourcoing", ddict)
+ self.assertIn("Grenoble", ddict)
+
+ self.assertNotIn("inhabitants", ddict["Grenoble"])
+ self.assertIn("coordinates", ddict["Grenoble"])
+ self.assertIn("area", ddict["Grenoble"])
+
+ def testAsArrayTrue(self):
+ """Test with asarray=True, the default"""
+ ddict = h5todict(self.h5_fname, path="/Europe/France/Grenoble")
+ self.assertTrue(numpy.array_equal(ddict["inhabitants"], numpy.array(inhabitants)))
+
+ def testAsArrayFalse(self):
+ """Test with asarray=False"""
+ ddict = h5todict(self.h5_fname, path="/Europe/France/Grenoble", asarray=False)
+ self.assertEqual(ddict["inhabitants"], inhabitants)
+
+ def testDereferenceLinks(self):
+ ddict = h5todict(self.h5_fname, path="links", dereference_links=True)
+ self.assertTrue(ddict["absolute_softlink"], 10)
+ self.assertTrue(ddict["relative_softlink"], 10)
+ self.assertTrue(ddict["external_link"], 10)
+ self.assertTrue(ddict["group"]["relative_softlink"], 10)
+
+ def testPreserveLinks(self):
+ ddict = h5todict(self.h5_fname, path="links", dereference_links=False)
+ self.assertTrue(is_link(ddict["absolute_softlink"]))
+ self.assertTrue(is_link(ddict["relative_softlink"]))
+ self.assertTrue(is_link(ddict["external_link"]))
+ self.assertTrue(is_link(ddict["group"]["relative_softlink"]))
+
+ def testStrings(self):
+ ddict = {"dset_bytes": b"bytes",
+ "dset_utf8": "utf8",
+ "dset_2bytes": [b"bytes", b"bytes"],
+ "dset_2utf8": ["utf8", "utf8"],
+ ("", "attr_bytes"): b"bytes",
+ ("", "attr_utf8"): "utf8",
+ ("", "attr_2bytes"): [b"bytes", b"bytes"],
+ ("", "attr_2utf8"): ["utf8", "utf8"]}
+ dicttoh5(ddict, self.h5_fname, mode="w")
+ adict = h5todict(self.h5_fname, include_attributes=True, asarray=False)
+ self.assertEqual(ddict["dset_bytes"], adict["dset_bytes"])
+ self.assertEqual(ddict["dset_utf8"], adict["dset_utf8"])
+ self.assertEqual(ddict[("", "attr_bytes")], adict[("", "attr_bytes")])
+ self.assertEqual(ddict[("", "attr_utf8")], adict[("", "attr_utf8")])
+ numpy.testing.assert_array_equal(ddict["dset_2bytes"], adict["dset_2bytes"])
+ numpy.testing.assert_array_equal(ddict["dset_2utf8"], adict["dset_2utf8"])
+ numpy.testing.assert_array_equal(ddict[("", "attr_2bytes")], adict[("", "attr_2bytes")])
+ numpy.testing.assert_array_equal(ddict[("", "attr_2utf8")], adict[("", "attr_2utf8")])
+
class TestDictToNx(unittest.TestCase):
def setUp(self):
self.tempdir = tempfile.mkdtemp()
self.h5_fname = os.path.join(self.tempdir, "nx.h5")
+ self.h5_ext_fname = os.path.join(self.tempdir, "nx_ext.h5")
def tearDown(self):
if os.path.exists(self.h5_fname):
os.unlink(self.h5_fname)
+ if os.path.exists(self.h5_ext_fname):
+ os.unlink(self.h5_ext_fname)
os.rmdir(self.tempdir)
def testAttributes(self):
"""Any kind of attribute can be described"""
ddict = {
- "group": {"datatset": "hmmm", "@group_attr": 10},
- "dataset": "aaaaaaaaaaaaaaa",
+ "group": {"dataset": 100, "@group_attr1": 10},
+ "dataset": 200,
"@root_attr": 11,
- "dataset@dataset_attr": 12,
+ "dataset@dataset_attr": "12",
"group@group_attr2": 13,
}
with h5py.File(self.h5_fname, "w") as h5file:
dictdump.dicttonx(ddict, h5file)
- self.assertEqual(h5file["group"].attrs['group_attr'], 10)
+ self.assertEqual(h5file["group"].attrs['group_attr1'], 10)
self.assertEqual(h5file.attrs['root_attr'], 11)
- self.assertEqual(h5file["dataset"].attrs['dataset_attr'], 12)
+ self.assertEqual(h5file["dataset"].attrs['dataset_attr'], "12")
self.assertEqual(h5file["group"].attrs['group_attr2'], 13)
def testKeyOrder(self):
@@ -280,36 +394,120 @@ class TestDictToNx(unittest.TestCase):
self.assertEqual(h5file["group/group/dataset"].attrs['attr'], 11)
self.assertEqual(h5file["group/group"].attrs['attr'], 12)
-
-class TestH5ToDict(unittest.TestCase):
+ def testLinks(self):
+ ddict = {"ext_group": {"dataset": 10}}
+ dictdump.dicttonx(ddict, self.h5_ext_fname)
+ ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"},
+ ">relative_softlink": "group/dataset",
+ ">absolute_softlink": "/links/group/dataset",
+ ">external_link": "nx_ext.h5::/ext_group/dataset"}}
+ dictdump.dicttonx(ddict, self.h5_fname)
+ with h5py.File(self.h5_fname, "r") as h5file:
+ self.assertEqual(h5file["links/group/dataset"][()], 10)
+ self.assertEqual(h5file["links/group/relative_softlink"][()], 10)
+ self.assertEqual(h5file["links/relative_softlink"][()], 10)
+ self.assertEqual(h5file["links/absolute_softlink"][()], 10)
+ self.assertEqual(h5file["links/external_link"][()], 10)
+
+ def testUpLinks(self):
+ ddict = {"data": {"group": {"dataset": 10, ">relative_softlink": "dataset"}},
+ "links": {"group": {"subgroup": {">relative_softlink": "../../../data/group/dataset"}}}}
+ dictdump.dicttonx(ddict, self.h5_fname)
+ with h5py.File(self.h5_fname, "r") as h5file:
+ self.assertEqual(h5file["/links/group/subgroup/relative_softlink"][()], 10)
+
+
+class TestNxToDict(unittest.TestCase):
def setUp(self):
self.tempdir = tempfile.mkdtemp()
- self.h5_fname = os.path.join(self.tempdir, "cityattrs.h5")
- dicttoh5(city_attrs, self.h5_fname)
+ self.h5_fname = os.path.join(self.tempdir, "nx.h5")
+ self.h5_ext_fname = os.path.join(self.tempdir, "nx_ext.h5")
def tearDown(self):
- os.unlink(self.h5_fname)
+ if os.path.exists(self.h5_fname):
+ os.unlink(self.h5_fname)
+ if os.path.exists(self.h5_ext_fname):
+ os.unlink(self.h5_ext_fname)
os.rmdir(self.tempdir)
- def testExcludeNames(self):
- ddict = h5todict(self.h5_fname, path="/Europe/France",
- exclude_names=["ourcoing", "inhab", "toto"])
- self.assertNotIn("Tourcoing", ddict)
- self.assertIn("Grenoble", ddict)
-
- self.assertNotIn("inhabitants", ddict["Grenoble"])
- self.assertIn("coordinates", ddict["Grenoble"])
- self.assertIn("area", ddict["Grenoble"])
-
- def testAsArrayTrue(self):
- """Test with asarray=True, the default"""
- ddict = h5todict(self.h5_fname, path="/Europe/France/Grenoble")
- self.assertTrue(numpy.array_equal(ddict["inhabitants"], numpy.array(inhabitants)))
-
- def testAsArrayFalse(self):
- """Test with asarray=False"""
- ddict = h5todict(self.h5_fname, path="/Europe/France/Grenoble", asarray=False)
- self.assertEqual(ddict["inhabitants"], inhabitants)
+ def testAttributes(self):
+ """Any kind of attribute can be described"""
+ ddict = {
+ "group": {"dataset": 100, "@group_attr1": 10},
+ "dataset": 200,
+ "@root_attr": 11,
+ "dataset@dataset_attr": "12",
+ "group@group_attr2": 13,
+ }
+ dictdump.dicttonx(ddict, self.h5_fname)
+ ddict = dictdump.nxtodict(self.h5_fname, include_attributes=True)
+ self.assertEqual(ddict["group"]["@group_attr1"], 10)
+ self.assertEqual(ddict["@root_attr"], 11)
+ self.assertEqual(ddict["dataset@dataset_attr"], "12")
+ self.assertEqual(ddict["group"]["@group_attr2"], 13)
+
+ def testDereferenceLinks(self):
+ """Write links and dereference on read"""
+ ddict = {"ext_group": {"dataset": 10}}
+ dictdump.dicttonx(ddict, self.h5_ext_fname)
+ ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"},
+ ">relative_softlink": "group/dataset",
+ ">absolute_softlink": "/links/group/dataset",
+ ">external_link": "nx_ext.h5::/ext_group/dataset"}}
+ dictdump.dicttonx(ddict, self.h5_fname)
+
+ ddict = dictdump.h5todict(self.h5_fname, dereference_links=True)
+ self.assertTrue(ddict["links"]["absolute_softlink"], 10)
+ self.assertTrue(ddict["links"]["relative_softlink"], 10)
+ self.assertTrue(ddict["links"]["external_link"], 10)
+ self.assertTrue(ddict["links"]["group"]["relative_softlink"], 10)
+
+ def testPreserveLinks(self):
+ """Write/read links"""
+ ddict = {"ext_group": {"dataset": 10}}
+ dictdump.dicttonx(ddict, self.h5_ext_fname)
+ ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"},
+ ">relative_softlink": "group/dataset",
+ ">absolute_softlink": "/links/group/dataset",
+ ">external_link": "nx_ext.h5::/ext_group/dataset"}}
+ dictdump.dicttonx(ddict, self.h5_fname)
+
+ ddict = dictdump.nxtodict(self.h5_fname, dereference_links=False)
+ self.assertTrue(ddict["links"][">absolute_softlink"], "dataset")
+ self.assertTrue(ddict["links"][">relative_softlink"], "group/dataset")
+ self.assertTrue(ddict["links"][">external_link"], "/links/group/dataset")
+ self.assertTrue(ddict["links"]["group"][">relative_softlink"], "nx_ext.h5::/ext_group/datase")
+
+ def testNotExistingPath(self):
+ """Test converting not existing path"""
+ with h5py.File(self.h5_fname, 'a') as f:
+ f['data'] = 1
+
+ ddict = h5todict(self.h5_fname, path="/I/am/not/a/path", errors='ignore')
+ self.assertFalse(ddict)
+
+ with TestLogging(dictdump_logger, error=1):
+ ddict = h5todict(self.h5_fname, path="/I/am/not/a/path", errors='log')
+ self.assertFalse(ddict)
+
+ with self.assertRaises(KeyError):
+ h5todict(self.h5_fname, path="/I/am/not/a/path", errors='raise')
+
+ def testBrokenLinks(self):
+ """Test with broken links"""
+ with h5py.File(self.h5_fname, 'a') as f:
+ f["/Mars/BrokenSoftLink"] = h5py.SoftLink("/Idontexists")
+ f["/Mars/BrokenExternalLink"] = h5py.ExternalLink("notexistingfile.h5", "/Idontexists")
+
+ ddict = h5todict(self.h5_fname, path="/Mars", errors='ignore')
+ self.assertFalse(ddict)
+
+ with TestLogging(dictdump_logger, error=2):
+ ddict = h5todict(self.h5_fname, path="/Mars", errors='log')
+ self.assertFalse(ddict)
+
+ with self.assertRaises(KeyError):
+ h5todict(self.h5_fname, path="/Mars", errors='raise')
class TestDictToJson(unittest.TestCase):
@@ -436,6 +634,7 @@ def suite():
test_suite.addTest(loadTests(TestDictToNx))
test_suite.addTest(loadTests(TestDictToJson))
test_suite.addTest(loadTests(TestH5ToDict))
+ test_suite.addTest(loadTests(TestNxToDict))
return test_suite
diff --git a/silx/io/test/test_spectoh5.py b/silx/io/test/test_spectoh5.py
index c3f03e9..903a62c 100644
--- a/silx/io/test/test_spectoh5.py
+++ b/silx/io/test/test_spectoh5.py
@@ -33,6 +33,7 @@ import h5py
from ..spech5 import SpecH5, SpecH5Group
from ..convert import convert, write_to_h5
+from ..utils import h5py_read_dataset
__authors__ = ["P. Knobel"]
__license__ = "MIT"
@@ -129,7 +130,7 @@ class TestConvertSpecHDF5(unittest.TestCase):
def testTitle(self):
"""Test the value of a dataset"""
- title12 = self.h5f["/1.2/title"][()]
+ title12 = h5py_read_dataset(self.h5f["/1.2/title"])
self.assertEqual(title12,
u"aaaaaa")
diff --git a/silx/io/test/test_url.py b/silx/io/test/test_url.py
index e68c67a..114f6a7 100644
--- a/silx/io/test/test_url.py
+++ b/silx/io/test/test_url.py
@@ -152,6 +152,16 @@ class TestDataUrl(unittest.TestCase):
expected = [True, True, None, "/a.h5", "/b", (5, 1)]
self.assertUrl(url, expected)
+ def test_slice2(self):
+ url = DataUrl("/a.h5?path=/b&slice=2:5")
+ expected = [True, True, None, "/a.h5", "/b", (slice(2, 5),)]
+ self.assertUrl(url, expected)
+
+ def test_slice3(self):
+ url = DataUrl("/a.h5?path=/b&slice=::2")
+ expected = [True, True, None, "/a.h5", "/b", (slice(None, None, 2),)]
+ self.assertUrl(url, expected)
+
def test_slice_ellipsis(self):
url = DataUrl("/a.h5?path=/b&slice=...")
expected = [True, True, None, "/a.h5", "/b", (Ellipsis, )]
diff --git a/silx/io/test/test_utils.py b/silx/io/test/test_utils.py
index 6c70636..13ab532 100644
--- a/silx/io/test/test_utils.py
+++ b/silx/io/test/test_utils.py
@@ -33,6 +33,7 @@ import unittest
import sys
from .. import utils
+from ..._version import calc_hexversion
import silx.io.url
import h5py
@@ -40,11 +41,9 @@ from ..utils import h5ls
import fabio
-
__authors__ = ["P. Knobel"]
__license__ = "MIT"
-__date__ = "12/02/2018"
-
+__date__ = "03/12/2020"
expected_spec1 = r"""#F .*
#D .*
@@ -67,6 +66,28 @@ expected_spec2 = expected_spec1 + r"""
2 8\.00
3 9\.00
"""
+
+expected_spec2reg = r"""#F .*
+#D .*
+
+#S 1 Ordinate1
+#D .*
+#N 3
+#L Abscissa Ordinate1 Ordinate2
+1 4\.00 7\.00
+2 5\.00 8\.00
+3 6\.00 9\.00
+"""
+
+expected_spec2irr = expected_spec1 + r"""
+#S 2 Ordinate2
+#D .*
+#N 2
+#L Abscissa Ordinate2
+1 7\.00
+2 8\.00
+"""
+
expected_csv = r"""Abscissa;Ordinate1;Ordinate2
1;4\.00;7\.00e\+00
2;5\.00;8\.00e\+00
@@ -83,6 +104,7 @@ expected_csv2 = r"""x;y0;y1
class TestSave(unittest.TestCase):
"""Test saving curves as SpecFile:
"""
+
def setUp(self):
self.tempdir = tempfile.mkdtemp()
self.spec_fname = os.path.join(self.tempdir, "savespec.dat")
@@ -92,6 +114,7 @@ class TestSave(unittest.TestCase):
self.x = [1, 2, 3]
self.xlab = "Abscissa"
self.y = [[4, 5, 6], [7, 8, 9]]
+ self.y_irr = [[4, 5, 6], [7, 8]]
self.ylabs = ["Ordinate1", "Ordinate2"]
def tearDown(self):
@@ -103,13 +126,6 @@ class TestSave(unittest.TestCase):
os.unlink(self.npy_fname)
shutil.rmtree(self.tempdir)
- def assertRegex(self, *args, **kwargs):
- # Python 2 compatibility
- if sys.version_info.major >= 3:
- return super(TestSave, self).assertRegex(*args, **kwargs)
- else:
- return self.assertRegexpMatches(*args, **kwargs)
-
def test_save_csv(self):
utils.save1D(self.csv_fname, self.x, self.y,
xlabel=self.xlab, ylabels=self.ylabs,
@@ -145,7 +161,6 @@ class TestSave(unittest.TestCase):
specf = open(self.spec_fname)
actual_spec = specf.read()
specf.close()
-
self.assertRegex(actual_spec, expected_spec1)
def test_savespec_file_handle(self):
@@ -165,18 +180,30 @@ class TestSave(unittest.TestCase):
specf = open(self.spec_fname)
actual_spec = specf.read()
specf.close()
-
self.assertRegex(actual_spec, expected_spec2)
- def test_save_spec(self):
- """Save SpecFile using save()"""
+ def test_save_spec_reg(self):
+ """Save SpecFile using save() on a regular pattern"""
utils.save1D(self.spec_fname, self.x, self.y, xlabel=self.xlab,
ylabels=self.ylabs, filetype="spec", fmt=["%d", "%.2f"])
specf = open(self.spec_fname)
actual_spec = specf.read()
specf.close()
- self.assertRegex(actual_spec, expected_spec2)
+
+ self.assertRegex(actual_spec, expected_spec2reg)
+
+ def test_save_spec_irr(self):
+ """Save SpecFile using save() on an irregular pattern"""
+ # invalid test case ?!
+ return
+ utils.save1D(self.spec_fname, self.x, self.y_irr, xlabel=self.xlab,
+ ylabels=self.ylabs, filetype="spec", fmt=["%d", "%.2f"])
+
+ specf = open(self.spec_fname)
+ actual_spec = specf.read()
+ specf.close()
+ self.assertRegex(actual_spec, expected_spec2irr)
def test_save_csv_no_labels(self):
"""Save csv using save(), with autoheader=True but
@@ -217,6 +244,7 @@ class TestH5Ls(unittest.TestCase):
<HDF5 dataset "data": shape (1,), type "<f8">
"""
+
def assertMatchAnyStringInList(self, pattern, list_of_strings):
for string_ in list_of_strings:
if re.match(pattern, string_):
@@ -395,6 +423,7 @@ class TestOpen(unittest.TestCase):
class TestNodes(unittest.TestCase):
"""Test `silx.io.utils.is_` functions."""
+
def test_real_h5py_objects(self):
name = tempfile.mktemp(suffix=".h5")
try:
@@ -417,45 +446,60 @@ class TestNodes(unittest.TestCase):
os.unlink(name)
def test_h5py_like_file(self):
+
class Foo(object):
+
def __init__(self):
self.h5_class = utils.H5Type.FILE
+
obj = Foo()
self.assertTrue(utils.is_file(obj))
self.assertTrue(utils.is_group(obj))
self.assertFalse(utils.is_dataset(obj))
def test_h5py_like_group(self):
+
class Foo(object):
+
def __init__(self):
self.h5_class = utils.H5Type.GROUP
+
obj = Foo()
self.assertFalse(utils.is_file(obj))
self.assertTrue(utils.is_group(obj))
self.assertFalse(utils.is_dataset(obj))
def test_h5py_like_dataset(self):
+
class Foo(object):
+
def __init__(self):
self.h5_class = utils.H5Type.DATASET
+
obj = Foo()
self.assertFalse(utils.is_file(obj))
self.assertFalse(utils.is_group(obj))
self.assertTrue(utils.is_dataset(obj))
def test_bad(self):
+
class Foo(object):
+
def __init__(self):
pass
+
obj = Foo()
self.assertFalse(utils.is_file(obj))
self.assertFalse(utils.is_group(obj))
self.assertFalse(utils.is_dataset(obj))
def test_bad_api(self):
+
class Foo(object):
+
def __init__(self):
self.h5_class = int
+
obj = Foo()
self.assertFalse(utils.is_file(obj))
self.assertFalse(utils.is_group(obj))
@@ -513,18 +557,20 @@ class TestGetData(unittest.TestCase):
def test_hdf5_array(self):
url = "silx:%s?/group/group/array" % self.h5_filename
data = utils.get_data(url=url)
- self.assertEqual(data.shape, (5, ))
+ self.assertEqual(data.shape, (5,))
self.assertEqual(data[0], 1)
def test_hdf5_array_slice(self):
url = "silx:%s?path=/group/group/array2d&slice=1" % self.h5_filename
data = utils.get_data(url=url)
- self.assertEqual(data.shape, (5, ))
+ self.assertEqual(data.shape, (5,))
self.assertEqual(data[0], 6)
def test_hdf5_array_slice_out_of_range(self):
url = "silx:%s?path=/group/group/array2d&slice=5" % self.h5_filename
- self.assertRaises(ValueError, utils.get_data, url)
+ # ValueError: h5py 2.x
+ # IndexError: h5py 3.x
+ self.assertRaises((ValueError, IndexError), utils.get_data, url)
def test_edf_using_silx(self):
url = "silx:%s?/scan_0/instrument/detector_0/data" % self.edf_filename
@@ -568,14 +614,15 @@ class TestGetData(unittest.TestCase):
def _h5_py_version_older_than(version):
- v_majeur, v_mineur, v_micro = h5py.version.version.split('.')[:3]
- r_majeur, r_mineur, r_micro = version.split('.')
- return v_majeur >= r_majeur and v_mineur >= r_mineur
+ v_majeur, v_mineur, v_micro = [int(i) for i in h5py.version.version.split('.')[:3]]
+ r_majeur, r_mineur, r_micro = [int(i) for i in version.split('.')]
+ return calc_hexversion(v_majeur, v_mineur, v_micro) >= calc_hexversion(r_majeur, r_mineur, r_micro)
@unittest.skipUnless(_h5_py_version_older_than('2.9.0'), 'h5py version < 2.9.0')
class TestRawFileToH5(unittest.TestCase):
"""Test conversion of .vol file to .h5 external dataset"""
+
def setUp(self):
self.tempdir = tempfile.mkdtemp()
self._vol_file = os.path.join(self.tempdir, 'test_vol.vol')
@@ -589,7 +636,7 @@ class TestRawFileToH5(unittest.TestCase):
assert os.path.exists(self._vol_file + '.npy')
os.rename(self._vol_file + '.npy', self._vol_file)
self.h5_file = os.path.join(self.tempdir, 'test_h5.h5')
- self.external_dataset_path= '/root/my_external_dataset'
+ self.external_dataset_path = '/root/my_external_dataset'
self._data_url = silx.io.url.DataUrl(file_path=self.h5_file,
data_path=self.external_dataset_path)
with open(self._file_info, 'w') as _fi:
@@ -672,6 +719,158 @@ class TestRawFileToH5(unittest.TestCase):
shape=self._dataset_shape))
+class TestH5Strings(unittest.TestCase):
+ """Test HDF5 str and bytes writing and reading"""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.tempdir = tempfile.mkdtemp()
+ cls.vlenstr = h5py.special_dtype(vlen=str)
+ cls.vlenbytes = h5py.special_dtype(vlen=bytes)
+ try:
+ cls.unicode = unicode
+ except NameError:
+ cls.unicode = str
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tempdir)
+
+ def setUp(self):
+ self.file = h5py.File(os.path.join(self.tempdir, 'file.h5'), mode="w")
+
+ def tearDown(self):
+ self.file.close()
+
+ @classmethod
+ def _make_array(cls, value, n):
+ if isinstance(value, bytes):
+ dtype = cls.vlenbytes
+ elif isinstance(value, cls.unicode):
+ dtype = cls.vlenstr
+ else:
+ return numpy.array([value] * n)
+ return numpy.array([value] * n, dtype=dtype)
+
+ @classmethod
+ def _get_charset(cls, value):
+ if isinstance(value, bytes):
+ return h5py.h5t.CSET_ASCII
+ elif isinstance(value, cls.unicode):
+ return h5py.h5t.CSET_UTF8
+ else:
+ return None
+
+ def _check_dataset(self, value, result=None):
+ # Write+read scalar
+ if result:
+ decode_ascii = True
+ else:
+ decode_ascii = False
+ result = value
+ charset = self._get_charset(value)
+ self.file["data"] = value
+ data = utils.h5py_read_dataset(self.file["data"], decode_ascii=decode_ascii)
+ assert type(data) == type(result), data
+ assert data == result, data
+ if charset:
+ assert self.file["data"].id.get_type().get_cset() == charset
+
+ # Write+read variable length
+ self.file["vlen_data"] = self._make_array(value, 2)
+ data = utils.h5py_read_dataset(self.file["vlen_data"], decode_ascii=decode_ascii, index=0)
+ assert type(data) == type(result), data
+ assert data == result, data
+ data = utils.h5py_read_dataset(self.file["vlen_data"], decode_ascii=decode_ascii)
+ numpy.testing.assert_array_equal(data, [result] * 2)
+ if charset:
+ assert self.file["vlen_data"].id.get_type().get_cset() == charset
+
+ def _check_attribute(self, value, result=None):
+ if result:
+ decode_ascii = True
+ else:
+ decode_ascii = False
+ result = value
+ self.file.attrs["data"] = value
+ data = utils.h5py_read_attribute(self.file.attrs, "data", decode_ascii=decode_ascii)
+ assert type(data) == type(result), data
+ assert data == result, data
+
+ self.file.attrs["vlen_data"] = self._make_array(value, 2)
+ data = utils.h5py_read_attribute(self.file.attrs, "vlen_data", decode_ascii=decode_ascii)
+ assert type(data[0]) == type(result), data[0]
+ assert data[0] == result, data[0]
+ numpy.testing.assert_array_equal(data, [result] * 2)
+
+ data = utils.h5py_read_attributes(self.file.attrs, decode_ascii=decode_ascii)["vlen_data"]
+ assert type(data[0]) == type(result), data[0]
+ assert data[0] == result, data[0]
+ numpy.testing.assert_array_equal(data, [result] * 2)
+
+ def test_dataset_ascii_bytes(self):
+ self._check_dataset(b"abc")
+
+ def test_attribute_ascii_bytes(self):
+ self._check_attribute(b"abc")
+
+ def test_dataset_ascii_bytes_decode(self):
+ self._check_dataset(b"abc", result="abc")
+
+ def test_attribute_ascii_bytes_decode(self):
+ self._check_attribute(b"abc", result="abc")
+
+ def test_dataset_ascii_str(self):
+ self._check_dataset("abc")
+
+ def test_attribute_ascii_str(self):
+ self._check_attribute("abc")
+
+ def test_dataset_utf8_str(self):
+ self._check_dataset("\u0101bc")
+
+ def test_attribute_utf8_str(self):
+ self._check_attribute("\u0101bc")
+
+ def test_dataset_utf8_bytes(self):
+ # 0xC481 is the byte representation of U+0101
+ self._check_dataset(b"\xc4\x81bc")
+
+ def test_attribute_utf8_bytes(self):
+ # 0xC481 is the byte representation of U+0101
+ self._check_attribute(b"\xc4\x81bc")
+
+ def test_dataset_utf8_bytes_decode(self):
+ # 0xC481 is the byte representation of U+0101
+ self._check_dataset(b"\xc4\x81bc", result="\u0101bc")
+
+ def test_attribute_utf8_bytes_decode(self):
+ # 0xC481 is the byte representation of U+0101
+ self._check_attribute(b"\xc4\x81bc", result="\u0101bc")
+
+ def test_dataset_latin1_bytes(self):
+ # extended ascii character 0xE4
+ self._check_dataset(b"\xe423")
+
+ def test_attribute_latin1_bytes(self):
+ # extended ascii character 0xE4
+ self._check_attribute(b"\xe423")
+
+ def test_dataset_latin1_bytes_decode(self):
+ # U+DCE4: surrogate for extended ascii character 0xE4
+ self._check_dataset(b"\xe423", result="\udce423")
+
+ def test_attribute_latin1_bytes_decode(self):
+ # U+DCE4: surrogate for extended ascii character 0xE4
+ self._check_attribute(b"\xe423", result="\udce423")
+
+ def test_dataset_no_string(self):
+ self._check_dataset(numpy.int64(10))
+
+ def test_attribute_no_string(self):
+ self._check_attribute(numpy.int64(10))
+
+
def suite():
loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
test_suite = unittest.TestSuite()
@@ -681,6 +880,7 @@ def suite():
test_suite.addTest(loadTests(TestNodes))
test_suite.addTest(loadTests(TestGetData))
test_suite.addTest(loadTests(TestRawFileToH5))
+ test_suite.addTest(loadTests(TestH5Strings))
return test_suite
diff --git a/silx/io/url.py b/silx/io/url.py
index 7607ae5..044977c 100644
--- a/silx/io/url.py
+++ b/silx/io/url.py
@@ -178,8 +178,20 @@ class DataUrl(object):
def str_to_slice(string):
if string == "...":
return Ellipsis
- elif string == ":":
- return slice(None)
+ elif ':' in string:
+ if string == ":":
+ return slice(None)
+ else:
+ def get_value(my_str):
+ if my_str in ('', None):
+ return None
+ else:
+ return int(my_str)
+ sss = string.split(':')
+ start = get_value(sss[0])
+ stop = get_value(sss[1] if len(sss) > 1 else None)
+ step = get_value(sss[2] if len(sss) > 2 else None)
+ return slice(start, stop, step)
else:
return int(string)
@@ -201,7 +213,10 @@ class DataUrl(object):
:param str path: Path representing the URL.
"""
self.__path = path
- path = path.replace("::", "?", 1)
+ # only replace if ? not here already. Otherwise can mess sith
+ # data_slice if == ::2 for example
+ if '?' not in path:
+ path = path.replace("::", "?", 1)
url = parse.urlparse(path)
is_valid = True
diff --git a/silx/io/utils.py b/silx/io/utils.py
index 5da344d..12e9a7e 100644
--- a/silx/io/utils.py
+++ b/silx/io/utils.py
@@ -25,8 +25,7 @@
__authors__ = ["P. Knobel", "V. Valls"]
__license__ = "MIT"
-__date__ = "18/04/2018"
-
+__date__ = "03/12/2020"
import enum
import os.path
@@ -40,18 +39,19 @@ import six
from silx.utils.proxy import Proxy
import silx.io.url
+from .._version import calc_hexversion
import h5py
+import h5py.h5t
+import h5py.h5a
try:
import h5pyd
except ImportError as e:
h5pyd = None
-
logger = logging.getLogger(__name__)
-
NEXUS_HDF5_EXT = [".h5", ".nx5", ".nxs", ".hdf", ".hdf5", ".cxi"]
"""List of possible extensions for HDF5 file formats."""
@@ -190,34 +190,46 @@ def save1D(fname, x, y, xlabel=None, ylabels=None, filetype=None,
if xlabel is None:
xlabel = "x"
if ylabels is None:
- if len(numpy.array(y).shape) > 1:
+ if numpy.array(y).ndim > 1:
ylabels = ["y%d" % i for i in range(len(y))]
else:
ylabels = ["y"]
elif isinstance(ylabels, (list, tuple)):
# if ylabels is provided as a list, every element must
# be a string
- ylabels = [ylabels[i] if ylabels[i] is not None else "y%d" % i
- for i in range(len(ylabels))]
+ ylabels = [ylabel if isinstance(ylabel, string_types) else "y%d" % i
+ for ylabel in ylabels]
if filetype.lower() == "spec":
- y_array = numpy.asarray(y)
-
- # make sure y_array is a 2D array even for a single curve
- if len(y_array.shape) == 1:
- y_array = y_array.reshape(1, y_array.shape[0])
- elif len(y_array.shape) > 2 or len(y_array.shape) < 1:
- raise IndexError("y must be a 1D or 2D array")
-
- # First curve
- specf = savespec(fname, x, y_array[0], xlabel, ylabels[0], fmt=fmt,
- scan_number=1, mode="w", write_file_header=True,
- close_file=False)
- # Other curves
- for i in range(1, y_array.shape[0]):
- specf = savespec(specf, x, y_array[i], xlabel, ylabels[i],
- fmt=fmt, scan_number=i + 1, mode="w",
- write_file_header=False, close_file=False)
+ # Check if we have regular data:
+ ref = len(x)
+ regular = True
+ for one_y in y:
+ regular &= len(one_y) == ref
+ if regular:
+ if isinstance(fmt, (list, tuple)) and len(fmt) < (len(ylabels) + 1):
+ fmt = fmt + [fmt[-1] * (1 + len(ylabels) - len(fmt))]
+ specf = savespec(fname, x, y, xlabel, ylabels, fmt=fmt,
+ scan_number=1, mode="w", write_file_header=True,
+ close_file=False)
+ else:
+ y_array = numpy.asarray(y)
+ # make sure y_array is a 2D array even for a single curve
+ if y_array.ndim == 1:
+ y_array.shape = 1, -1
+ elif y_array.ndim not in [1, 2]:
+ raise IndexError("y must be a 1D or 2D array")
+
+ # First curve
+ specf = savespec(fname, x, y_array[0], xlabel, ylabels[0], fmt=fmt,
+ scan_number=1, mode="w", write_file_header=True,
+ close_file=False)
+ # Other curves
+ for i in range(1, y_array.shape[0]):
+ specf = savespec(specf, x, y_array[i], xlabel, ylabels[i],
+ fmt=fmt, scan_number=i + 1, mode="w",
+ write_file_header=False, close_file=False)
+
# close file if we created it
if not hasattr(fname, "write"):
specf.close()
@@ -307,9 +319,11 @@ def savespec(specfile, x, y, xlabel="X", ylabel="Y", fmt="%.7g",
or append mode. If a file name is provided, a new file is open in
write mode (existing file with the same name will be lost)
:param x: 1D-Array (or list) of abscissa values
- :param y: 1D-array (or list) of ordinates values
+ :param y: 1D-array (or list), or list of them of ordinates values.
+ All dataset must have the same length as x
:param xlabel: Abscissa label (default ``"X"``)
- :param ylabel: Ordinate label
+ :param ylabel: Ordinate label, may be a list of labels when multiple curves
+ are to be saved together.
:param fmt: Format string for data. You can specify a short format
string that defines a single format for both ``x`` and ``y`` values,
or a list of two different format strings (e.g. ``["%d", "%.7g"]``).
@@ -333,40 +347,51 @@ def savespec(specfile, x, y, xlabel="X", ylabel="Y", fmt="%.7g",
x_array = numpy.asarray(x)
y_array = numpy.asarray(y)
+ if y_array.ndim > 2:
+ raise IndexError("Y columns must have be packed as 1D")
- if y_array.shape[0] != x_array.shape[0]:
+ if y_array.shape[-1] != x_array.shape[0]:
raise IndexError("X and Y columns must have the same length")
+ if y_array.ndim == 2:
+ assert isinstance(ylabel, (list, tuple))
+ assert y_array.shape[0] == len(ylabel)
+ labels = (xlabel, *ylabel)
+ else:
+ labels = (xlabel, ylabel)
+ data = numpy.vstack((x_array, y_array))
+ ncol = data.shape[0]
+ assert len(labels) == ncol
+
+ print(xlabel, ylabel, fmt, ncol, x_array, y_array)
if isinstance(fmt, string_types) and fmt.count("%") == 1:
- full_fmt_string = fmt + " " + fmt + "\n"
- elif isinstance(fmt, (list, tuple)) and len(fmt) == 2:
- full_fmt_string = " ".join(fmt) + "\n"
+ full_fmt_string = " ".join([fmt] * ncol)
+ elif isinstance(fmt, (list, tuple)) and len(fmt) == ncol:
+ full_fmt_string = " ".join(fmt)
else:
- raise ValueError("fmt must be a single format string or a list of " +
- "two format strings")
+ raise ValueError("`fmt` must be a single format string or a list of " +
+ "format strings with as many format as ncolumns")
if not hasattr(specfile, "write"):
f = builtin_open(specfile, mode)
else:
f = specfile
- output = ""
-
- current_date = "#D %s\n" % (time.ctime(time.time()))
-
+ current_date = "#D %s" % (time.ctime(time.time()))
if write_file_header:
- output += "#F %s\n" % f.name
- output += current_date
- output += "\n"
-
- output += "#S %d %s\n" % (scan_number, ylabel)
- output += current_date
- output += "#N 2\n"
- output += "#L %s %s\n" % (xlabel, ylabel)
- for i in range(y_array.shape[0]):
- output += full_fmt_string % (x_array[i], y_array[i])
- output += "\n"
+ lines = [ "#F %s" % f.name, current_date, ""]
+ else:
+ lines = [""]
+ lines += [ "#S %d %s" % (scan_number, labels[1]),
+ current_date,
+ "#N %d" % ncol,
+ "#L " + " ".join(labels)]
+
+ for i in data.T:
+ lines.append(full_fmt_string % tuple(i))
+ lines.append("")
+ output = "\n".join(lines)
f.write(output.encode())
if close_file:
@@ -406,7 +431,7 @@ def h5ls(h5group, lvl=0):
if is_group(h5group):
h5f = h5group
elif isinstance(h5group, string_types):
- h5f = open(h5group) # silx.io.open
+ h5f = open(h5group) # silx.io.open
else:
raise TypeError("h5group must be a hdf5-like group object or a file name.")
@@ -735,6 +760,26 @@ def is_softlink(obj):
return t == H5Type.SOFT_LINK
+def is_externallink(obj):
+ """
+ True if the object is a h5py.ExternalLink-like object.
+
+ :param obj: An object
+ """
+ t = get_h5_class(obj)
+ return t == H5Type.EXTERNAL_LINK
+
+
+def is_link(obj):
+ """
+ True if the object is a h5py link-like object.
+
+ :param obj: An object
+ """
+ t = get_h5_class(obj)
+ return t in {H5Type.SOFT_LINK, H5Type.EXTERNAL_LINK}
+
+
def get_data(url):
"""Returns a numpy data from an URL.
@@ -791,16 +836,16 @@ def get_data(url):
raise ValueError("Data path from URL '%s' is not a dataset" % url.path())
if data_slice is not None:
- data = data[data_slice]
+ data = h5py_read_dataset(data, index=data_slice)
else:
# works for scalar and array
- data = data[()]
+ data = h5py_read_dataset(data)
elif url.scheme() == "fabio":
import fabio
data_slice = url.data_slice()
if data_slice is None:
- data_slice = (0, )
+ data_slice = (0,)
if data_slice is None or len(data_slice) != 1:
raise ValueError("Fabio slice expect a single frame, but %s found" % data_slice)
index = data_slice[0]
@@ -844,8 +889,8 @@ def rawfile_to_h5_external_dataset(bin_file, output_url, shape, dtype,
"""
assert isinstance(output_url, silx.io.url.DataUrl)
assert isinstance(shape, (tuple, list))
- v_majeur, v_mineur, v_micro = h5py.version.version.split('.')
- if v_majeur <= '2' and v_mineur < '9':
+ v_majeur, v_mineur, v_micro = [int(i) for i in h5py.version.version.split('.')[:3]]
+ if calc_hexversion(v_majeu