summaryrefslogtreecommitdiff
path: root/silx
diff options
context:
space:
mode:
Diffstat (limited to 'silx')
-rw-r--r--silx/__main__.py5
-rw-r--r--silx/app/view/test/test_launcher.py74
-rwxr-xr-xsilx/gui/colors.py4
-rw-r--r--silx/gui/data/ArrayTableModel.py78
-rw-r--r--silx/gui/data/DataViews.py6
-rw-r--r--silx/gui/data/NXdataWidgets.py9
-rw-r--r--silx/gui/data/RecordTableView.py42
-rw-r--r--silx/gui/data/test/test_arraywidget.py15
-rw-r--r--silx/gui/dialog/ColormapDialog.py178
-rw-r--r--silx/gui/plot/ColorBar.py18
-rw-r--r--silx/gui/plot/Colormap.py4
-rw-r--r--silx/gui/plot/ImageStack.py43
-rw-r--r--silx/gui/plot/ImageView.py812
-rwxr-xr-xsilx/gui/plot/LegendSelector.py48
-rw-r--r--silx/gui/plot/MaskToolsWidget.py187
-rwxr-xr-xsilx/gui/plot/PlotWidget.py191
-rw-r--r--silx/gui/plot/Profile.py34
-rw-r--r--silx/gui/plot/StatsWidget.py6
-rw-r--r--silx/gui/plot/_BaseMaskToolsWidget.py31
-rwxr-xr-xsilx/gui/plot/actions/control.py17
-rw-r--r--silx/gui/plot/actions/histogram.py366
-rwxr-xr-xsilx/gui/plot/backends/BackendMatplotlib.py12
-rwxr-xr-xsilx/gui/plot/backends/BackendOpenGL.py5
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotCurve.py17
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotImage.py11
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotItem.py7
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotTriangles.py7
-rw-r--r--silx/gui/plot/items/_arc_roi.py11
-rw-r--r--silx/gui/plot/items/complex.py84
-rw-r--r--silx/gui/plot/items/core.py36
-rw-r--r--silx/gui/plot/items/histogram.py54
-rw-r--r--silx/gui/plot/items/image.py125
-rw-r--r--silx/gui/plot/items/scatter.py8
-rw-r--r--silx/gui/plot/stats/stats.py52
-rw-r--r--silx/gui/plot/test/testMaskToolsWidget.py168
-rw-r--r--silx/gui/plot/test/testPixelIntensityHistoAction.py16
-rwxr-xr-xsilx/gui/plot/test/testPlotWidget.py179
-rw-r--r--silx/gui/plot/tools/PositionInfo.py90
-rw-r--r--silx/gui/plot/tools/RadarView.py361
-rw-r--r--silx/gui/plot/tools/profile/core.py5
-rw-r--r--silx/gui/plot/tools/profile/manager.py6
-rw-r--r--silx/gui/plot/tools/profile/rois.py28
-rw-r--r--silx/gui/plot3d/items/__init__.py4
-rw-r--r--silx/gui/plot3d/items/image.py251
-rw-r--r--silx/gui/plot3d/scene/primitives.py9
-rw-r--r--silx/gui/plot3d/test/testSceneWindow.py40
-rw-r--r--silx/gui/plot3d/tools/PositionInfoWidget.py4
-rw-r--r--silx/gui/qt/__init__.py4
-rw-r--r--silx/gui/setup.py3
-rw-r--r--silx/gui/utils/glutils/__init__.py (renamed from silx/gui/utils/glutils.py)13
-rw-r--r--silx/gui/widgets/FlowLayout.py2
-rw-r--r--silx/gui/widgets/UrlSelectionTable.py4
-rw-r--r--silx/image/bilinear.pyx271
-rw-r--r--silx/image/test/test_bilinear.py52
-rw-r--r--silx/io/dictdump.py303
-rw-r--r--silx/io/h5py_utils.py317
-rw-r--r--silx/io/spech5.py30
-rw-r--r--silx/io/test/__init__.py2
-rw-r--r--silx/io/test/test_dictdump.py411
-rw-r--r--silx/io/test/test_h5py_utils.py397
-rw-r--r--silx/io/url.py2
-rw-r--r--silx/opencl/codec/test/test_byte_offset.py13
-rw-r--r--silx/opencl/common.py4
-rw-r--r--silx/opencl/processing.py48
-rw-r--r--silx/opencl/statistics.py44
-rw-r--r--silx/opencl/test/__init__.py4
-rw-r--r--silx/opencl/test/test_doubleword.py258
-rw-r--r--silx/opencl/test/test_kahan.py10
-rw-r--r--silx/opencl/test/test_stats.py42
-rw-r--r--silx/resources/opencl/doubleword.cl115
-rw-r--r--silx/resources/opencl/statistics.cl177
-rw-r--r--silx/test/__init__.py9
-rw-r--r--silx/utils/retry.py264
-rwxr-xr-xsilx/utils/test/__init__.py2
-rw-r--r--silx/utils/test/test_retry.py179
-rwxr-xr-xsilx/utils/test/test_testutils.py9
-rwxr-xr-xsilx/utils/testutils.py15
77 files changed, 5517 insertions, 1245 deletions
diff --git a/silx/__main__.py b/silx/__main__.py
index a971390..f832a09 100644
--- a/silx/__main__.py
+++ b/silx/__main__.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -38,6 +38,7 @@ __date__ = "07/06/2018"
import logging
logging.basicConfig()
+import multiprocessing
import sys
from silx.utils.launcher import Launcher
import silx._version
@@ -52,6 +53,8 @@ def main():
:rtype: int
:returns: The execution status
"""
+ multiprocessing.freeze_support()
+
launcher = Launcher(prog="silx", version=silx._version.version)
launcher.add_command("view",
module_name="silx.app.view.main",
diff --git a/silx/app/view/test/test_launcher.py b/silx/app/view/test/test_launcher.py
index aabccf0..5f03de9 100644
--- a/silx/app/view/test/test_launcher.py
+++ b/silx/app/view/test/test_launcher.py
@@ -30,7 +30,9 @@ __date__ = "07/06/2018"
import os
+import shutil
import sys
+import tempfile
import unittest
import logging
import subprocess
@@ -74,38 +76,46 @@ class TestLauncher(unittest.TestCase):
result = e.args[0]
self.assertEqual(result, 0)
- def executeCommandLine(self, command_line, env):
+ def executeAsScript(self, filename, *args):
"""Execute a command line.
Log output as debug in case of bad return code.
"""
- _logger.info("Execute: %s", " ".join(command_line))
- p = subprocess.Popen(command_line,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- env=env)
- out, err = p.communicate()
- _logger.info("Return code: %d", p.returncode)
- try:
- out = out.decode('utf-8')
- except UnicodeError:
- pass
- try:
- err = err.decode('utf-8')
- except UnicodeError:
- pass
-
- if p.returncode != 0:
- _logger.info("stdout:")
- _logger.info("%s", out)
- _logger.info("stderr:")
- _logger.info("%s", err)
- else:
- _logger.debug("stdout:")
- _logger.debug("%s", out)
- _logger.debug("stderr:")
- _logger.debug("%s", err)
- self.assertEqual(p.returncode, 0)
+ env = self.createTestEnv()
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Copy file to temporary dir to avoid import from current dir.
+ script = os.path.join(tmpdir, 'launcher.py')
+ shutil.copyfile(filename, script)
+ command_line = [sys.executable, script] + list(args)
+
+ _logger.info("Execute: %s", " ".join(command_line))
+ p = subprocess.Popen(command_line,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ env=env)
+ out, err = p.communicate()
+ _logger.info("Return code: %d", p.returncode)
+ try:
+ out = out.decode('utf-8')
+ except UnicodeError:
+ pass
+ try:
+ err = err.decode('utf-8')
+ except UnicodeError:
+ pass
+
+ if p.returncode != 0:
+ _logger.info("stdout:")
+ _logger.info("%s", out)
+ _logger.info("stderr:")
+ _logger.info("%s", err)
+ else:
+ _logger.debug("stdout:")
+ _logger.debug("%s", out)
+ _logger.debug("stderr:")
+ _logger.debug("%s", err)
+ self.assertEqual(p.returncode, 0)
def createTestEnv(self):
"""
@@ -120,18 +130,14 @@ class TestLauncher(unittest.TestCase):
Uses subprocess to avoid to parasite the current environment.
"""
- env = self.createTestEnv()
- commandLine = [sys.executable, main.__file__, "--help"]
- self.executeCommandLine(commandLine, env)
+ self.executeAsScript(main.__file__, "--help")
def testExecuteSilxViewHelp(self):
"""Test if the main module is well connected.
Uses subprocess to avoid to parasite the current environment.
"""
- env = self.createTestEnv()
- commandLine = [sys.executable, silx_main.__file__, "view", "--help"]
- self.executeCommandLine(commandLine, env)
+ self.executeAsScript(silx_main.__file__, "view", "--help")
def suite():
diff --git a/silx/gui/colors.py b/silx/gui/colors.py
index 4a96ae0..db837b5 100755
--- a/silx/gui/colors.py
+++ b/silx/gui/colors.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -45,7 +45,7 @@ from silx.utils import deprecation
from silx.resources import resource_filename as _resource_filename
-_logger = logging.getLogger(__file__)
+_logger = logging.getLogger(__name__)
try:
import silx.gui.utils.matplotlib # noqa Initalize matplotlib
diff --git a/silx/gui/data/ArrayTableModel.py b/silx/gui/data/ArrayTableModel.py
index 8805241..b7bd9c4 100644
--- a/silx/gui/data/ArrayTableModel.py
+++ b/silx/gui/data/ArrayTableModel.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -74,6 +74,10 @@ class ArrayTableModel(qt.QAbstractTableModel):
:param sequence[int] perspective: See documentation
of :meth:`setPerspective`.
"""
+
+ MAX_NUMBER_OF_SECTIONS = 10e6
+ """Maximum number of displayed rows and columns"""
+
def __init__(self, parent=None, data=None, perspective=None):
qt.QAbstractTableModel.__init__(self, parent)
@@ -173,7 +177,7 @@ class ArrayTableModel(qt.QAbstractTableModel):
if row_dim is None:
# 0-D and 1-D arrays
return 1
- return self._array.shape[row_dim]
+ return min(self._array.shape[row_dim], self.MAX_NUMBER_OF_SECTIONS)
def columnCount(self, parent_idx=None):
"""QAbstractTableModel method
@@ -182,14 +186,55 @@ class ArrayTableModel(qt.QAbstractTableModel):
if col_dim is None:
# 0-D array
return 1
- return self._array.shape[col_dim]
+ return min(self._array.shape[col_dim], self.MAX_NUMBER_OF_SECTIONS)
+
+ def __isClipped(self, orientation=qt.Qt.Vertical) -> bool:
+ """Returns whether or not array is clipped in a given orientation"""
+ if orientation == qt.Qt.Vertical:
+ dim = self._getRowDim()
+ else:
+ dim = self._getColumnDim()
+ return (dim is not None and
+ self._array.shape[dim] > self.MAX_NUMBER_OF_SECTIONS)
+
+ def __isClippedIndex(self, index) -> bool:
+ """Returns whether or not index's cell represents clipped data."""
+ if not index.isValid():
+ return False
+ if index.row() == self.MAX_NUMBER_OF_SECTIONS - 2:
+ return self.__isClipped(qt.Qt.Vertical)
+ if index.column() == self.MAX_NUMBER_OF_SECTIONS - 2:
+ return self.__isClipped(qt.Qt.Horizontal)
+ return False
+
+ def __clippedData(self, role=qt.Qt.DisplayRole):
+ """Return data for cells representing clipped data"""
+ if role == qt.Qt.DisplayRole:
+ return "..."
+ elif role == qt.Qt.ToolTipRole:
+ return "Dataset is too large: display is clipped"
+ else:
+ return None
def data(self, index, role=qt.Qt.DisplayRole):
"""QAbstractTableModel method to access data values
in the format ready to be displayed"""
if index.isValid():
- selection = self._getIndexTuple(index.row(),
- index.column())
+ if self.__isClippedIndex(index): # Special displayed for clipped data
+ return self.__clippedData(role)
+
+ row, column = index.row(), index.column()
+
+ # When clipped, display last data of the array in last column of the table
+ if (self.__isClipped(qt.Qt.Vertical) and
+ row == self.MAX_NUMBER_OF_SECTIONS - 1):
+ row = self._array.shape[self._getRowDim()] - 1
+ if (self.__isClipped(qt.Qt.Horizontal) and
+ column == self.MAX_NUMBER_OF_SECTIONS - 1):
+ column = self._array.shape[self._getColumnDim()] - 1
+
+ selection = self._getIndexTuple(row, column)
+
if role == qt.Qt.DisplayRole:
return self._formatter.toString(self._array[selection], self._array.dtype)
@@ -224,17 +269,30 @@ class ArrayTableModel(qt.QAbstractTableModel):
"""QAbstractTableModel method
Return the 0-based row or column index, for display in the
horizontal and vertical headers"""
+ if self.__isClipped(orientation): # Header is clipped
+ if section == self.MAX_NUMBER_OF_SECTIONS - 2:
+ # Represent clipped data
+ return self.__clippedData(role)
+
+ elif section == self.MAX_NUMBER_OF_SECTIONS - 1:
+ # Display last index from data not table
+ if role == qt.Qt.DisplayRole:
+ if orientation == qt.Qt.Vertical:
+ dim = self._getRowDim()
+ else:
+ dim = self._getColumnDim()
+ return str(self._array.shape[dim] - 1)
+ else:
+ return None
+
if role == qt.Qt.DisplayRole:
- if orientation == qt.Qt.Vertical:
- return "%d" % section
- if orientation == qt.Qt.Horizontal:
- return "%d" % section
+ return "%d" % section
return None
def flags(self, index):
"""QAbstractTableModel method to inform the view whether data
is editable or not."""
- if not self._editable:
+ if not self._editable or self.__isClippedIndex(index):
return qt.QAbstractTableModel.flags(self, index)
return qt.QAbstractTableModel.flags(self, index) | qt.Qt.ItemIsEditable
diff --git a/silx/gui/data/DataViews.py b/silx/gui/data/DataViews.py
index d9958de..b18a813 100644
--- a/silx/gui/data/DataViews.py
+++ b/silx/gui/data/DataViews.py
@@ -1042,6 +1042,8 @@ class _Plot2dView(DataView):
widget.setKeepDataAspectRatio(True)
widget.getXAxis().setLabel('X')
widget.getYAxis().setLabel('Y')
+ maskToolsWidget = widget.getMaskToolsDockWidget().widget()
+ maskToolsWidget.setItemMaskUpdated(True)
return widget
def clear(self):
@@ -1156,6 +1158,8 @@ class _ComplexImageView(DataView):
widget.getPlot().setKeepDataAspectRatio(True)
widget.getXAxis().setLabel('X')
widget.getYAxis().setLabel('Y')
+ maskToolsWidget = widget.getPlot().getMaskToolsDockWidget().widget()
+ maskToolsWidget.setItemMaskUpdated(True)
return widget
def clear(self):
@@ -1254,6 +1258,8 @@ class _StackView(DataView):
widget.setLabels(self.axesNames(None, None))
# hide default option panel
widget.setOptionVisible(False)
+ maskToolWidget = widget.getPlotWidget().getMaskToolsDockWidget().widget()
+ maskToolWidget.setItemMaskUpdated(True)
return widget
def clear(self):
diff --git a/silx/gui/data/NXdataWidgets.py b/silx/gui/data/NXdataWidgets.py
index 271b267..be7d0e3 100644
--- a/silx/gui/data/NXdataWidgets.py
+++ b/silx/gui/data/NXdataWidgets.py
@@ -371,6 +371,8 @@ class ArrayImagePlot(qt.QWidget):
normalization=Colormap.LINEAR))
self._plot.getIntensityHistogramAction().setVisible(True)
self._plot.setKeepDataAspectRatio(True)
+ maskToolWidget = self._plot.getMaskToolsDockWidget().widget()
+ maskToolWidget.setItemMaskUpdated(True)
# not closable
self._selector = NumpyAxesSelector(self)
@@ -511,7 +513,7 @@ class ArrayImagePlot(qt.QWidget):
self._plot.getYAxis().setScale('linear')
self._plot.addImage(image, legend=legend,
origin=origin, scale=scale,
- replace=True)
+ replace=True, resetzoom=False)
else:
xaxisscale, yaxisscale = self._axis_scales
@@ -587,6 +589,8 @@ class ArrayComplexImagePlot(qt.QWidget):
self._plot.getPlot().getIntensityHistogramAction().setVisible(True)
self._plot.setKeepDataAspectRatio(True)
+ maskToolWidget = self._plot.getPlot().getMaskToolsDockWidget().widget()
+ maskToolWidget.setItemMaskUpdated(True)
# not closable
self._selector = NumpyAxesSelector(self)
@@ -769,6 +773,9 @@ class ArrayStackPlot(qt.QWidget):
self.__x_axis_name = None
self._stack_view = StackView(self)
+ maskToolWidget = self._stack_view.getPlotWidget().getMaskToolsDockWidget().widget()
+ maskToolWidget.setItemMaskUpdated(True)
+
self._hline = qt.QFrame(self)
self._hline.setFrameStyle(qt.QFrame.HLine)
self._hline.setFrameShadow(qt.QFrame.Sunken)
diff --git a/silx/gui/data/RecordTableView.py b/silx/gui/data/RecordTableView.py
index b1b7dcd..2c0011a 100644
--- a/silx/gui/data/RecordTableView.py
+++ b/silx/gui/data/RecordTableView.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -149,6 +149,10 @@ class RecordTableModel(qt.QAbstractTableModel):
:param qt.QObject parent: Parent object
:param numpy.ndarray data: A numpy array or a h5py dataset
"""
+
+ MAX_NUMBER_OF_ROWS = 10e6
+ """Maximum number of display values of the dataset"""
+
def __init__(self, parent=None, data=None):
qt.QAbstractTableModel.__init__(self, parent)
@@ -170,7 +174,7 @@ class RecordTableModel(qt.QAbstractTableModel):
elif not self.__is_array:
return 1
else:
- return len(self.__data)
+ return min(len(self.__data), self.MAX_NUMBER_OF_ROWS)
def columnCount(self, parent_idx=None):
"""Returns number of columns to be displayed in table"""
@@ -179,6 +183,15 @@ class RecordTableModel(qt.QAbstractTableModel):
else:
return len(self.__fields)
+ def __clippedData(self, role=qt.Qt.DisplayRole):
+ """Return data for cells representing clipped data"""
+ if role == qt.Qt.DisplayRole:
+ return "..."
+ elif role == qt.Qt.ToolTipRole:
+ return "Dataset is too large: display is clipped"
+ else:
+ return None
+
def data(self, index, role=qt.Qt.DisplayRole):
"""QAbstractTableModel method to access data values
in the format ready to be displayed"""
@@ -188,10 +201,19 @@ class RecordTableModel(qt.QAbstractTableModel):
if self.__data is None:
return None
+ # Special display of one before last data for clipped table
+ if self.__isClipped() and index.row() == self.rowCount() - 2:
+ return self.__clippedData(role)
+
if self.__is_array:
- if index.row() >= len(self.__data):
+ row = index.row()
+ if row >= self.rowCount():
return None
- data = self.__data[index.row()]
+ elif self.__isClipped() and row == self.rowCount() - 1:
+ # Clipped array, display last value at the end
+ data = self.__data[-1]
+ else:
+ data = self.__data[row]
else:
if index.row() > 0:
return None
@@ -221,10 +243,18 @@ class RecordTableModel(qt.QAbstractTableModel):
# PyQt4 send -1 when there is columns but no rows
return None
+ # Handle clipping of huge tables
+ if (self.__isClipped() and
+ orientation == qt.Qt.Vertical and
+ section == self.rowCount() - 2):
+ return self.__clippedData(role)
+
if role == qt.Qt.DisplayRole:
if orientation == qt.Qt.Vertical:
if not self.__is_array:
return "Scalar"
+ elif section == self.MAX_NUMBER_OF_ROWS - 1:
+ return str(len(self.__data) - 1)
else:
return str(section)
if orientation == qt.Qt.Horizontal:
@@ -246,6 +276,10 @@ class RecordTableModel(qt.QAbstractTableModel):
"""
return qt.QAbstractTableModel.flags(self, index)
+ def __isClipped(self) -> bool:
+ """Returns whether the displayed array is clipped or not"""
+ return self.__data is not None and self.__is_array and len(self.__data) > self.MAX_NUMBER_OF_ROWS
+
def setArrayData(self, data):
"""Set the data array and the viewing perspective.
diff --git a/silx/gui/data/test/test_arraywidget.py b/silx/gui/data/test/test_arraywidget.py
index 7785ac5..87081ed 100644
--- a/silx/gui/data/test/test_arraywidget.py
+++ b/silx/gui/data/test/test_arraywidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -34,6 +34,7 @@ import numpy
from silx.gui import qt
from silx.gui.data import ArrayTableWidget
+from silx.gui.data.ArrayTableModel import ArrayTableModel
from silx.gui.utils.testutils import TestCaseQt
import h5py
@@ -186,6 +187,18 @@ class TestArrayWidget(TestCaseQt):
b1 = self.aw.getData(copy=False)
self.assertIs(b0, b1)
+ def testClipping(self):
+ """Test clipping of large arrays"""
+ self.aw.show()
+ self.qWaitForWindowExposed(self.aw)
+
+ data = numpy.arange(ArrayTableModel.MAX_NUMBER_OF_SECTIONS + 10)
+
+ for shape in [(1, -1), (-1, 1)]:
+ with self.subTest(shape=shape):
+ self.aw.setArrayData(data.reshape(shape), editable=True)
+ self.qapp.processEvents()
+
class TestH5pyArrayWidget(TestCaseQt):
"""Basic test for ArrayTableWidget with a dataset.
diff --git a/silx/gui/dialog/ColormapDialog.py b/silx/gui/dialog/ColormapDialog.py
index 6b5d83b..ca7ee97 100644
--- a/silx/gui/dialog/ColormapDialog.py
+++ b/silx/gui/dialog/ColormapDialog.py
@@ -59,12 +59,9 @@ The updates of the colormap description are also available through the signal:
:attr:`ColormapDialog.sigColormapChanged`.
""" # noqa
-from __future__ import division
-
__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
__license__ = "MIT"
-__date__ = "27/11/2018"
-
+__date__ = "08/12/2020"
import enum
import logging
@@ -73,7 +70,7 @@ import numpy
from .. import qt
from .. import utils
-from ..colors import Colormap
+from ..colors import Colormap, cursorColorForColormap
from ..plot import PlotWidget
from ..plot.items.axis import Axis
from ..plot.items import BoundingRect
@@ -84,12 +81,14 @@ from silx.gui.plot import items
from silx.gui import icons
from silx.gui.qt import inspect as qtinspect
from silx.gui.widgets.ColormapNameComboBox import ColormapNameComboBox
+from silx.gui.widgets.WaitingPushButton import WaitingPushButton
from silx.math.histogram import Histogramnd
from silx.utils import deprecation
+from silx.gui.plot.items.roi import RectangleROI
+from silx.gui.plot.tools.roi import RegionOfInterestManager
_logger = logging.getLogger(__name__)
-
_colormapIconPreview = {}
@@ -510,6 +509,7 @@ class _ColormapHistogram(qt.QWidget):
:returns: Tuple{float, float}
"""
scale = self._plot.getXAxis().getScale()
+
def isDisplayable(pos):
if pos is None:
return False
@@ -770,7 +770,8 @@ class _ColormapHistogram(qt.QWidget):
else:
histogram = numpy.array(histogram, copy=True)
bin_edges = numpy.array(bin_edges, copy=True)
- norm_histogram = histogram / max(histogram)
+ with numpy.errstate(invalid='ignore'):
+ norm_histogram = histogram / numpy.nanmax(histogram)
self._plot.addHistogram(norm_histogram,
bin_edges,
legend="Data",
@@ -917,6 +918,27 @@ class ColormapDialog(qt.QDialog):
self._histoWidget.sigRangeMoving.connect(self._histogramRangeMoving)
self._histoWidget.sigRangeMoved.connect(self._histogramRangeMoved)
+ # Scale to buttons
+ self._visibleAreaButton = qt.QPushButton(self)
+ self._visibleAreaButton.setEnabled(False)
+ self._visibleAreaButton.setText("Visible Area")
+ self._visibleAreaButton.clicked.connect(
+ self._handleScaleToVisibleAreaClicked,
+ type=qt.Qt.QueuedConnection)
+
+ # Place-holder for selected area ROI manager
+ self._roiForColormapManager = None
+
+ self._selectedAreaButton = WaitingPushButton(self)
+ self._selectedAreaButton.setEnabled(False)
+ self._selectedAreaButton.setText("Selection")
+ self._selectedAreaButton.setIcon(icons.getQIcon("add-shape-rectangle"))
+ self._selectedAreaButton.setCheckable(True)
+ self._selectedAreaButton.setDisabledWhenWaiting(False)
+ self._selectedAreaButton.toggled.connect(
+ self._handleScaleToSelectionToggled,
+ type=qt.Qt.QueuedConnection)
+
# define modal buttons
types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel
self._buttonsModal = qt.QDialogButtonBox(parent=self)
@@ -955,6 +977,16 @@ class ColormapDialog(qt.QDialog):
label.setToolTip("Mode for autoscale. Algorithm used to find range in auto scale.")
formLayout.addItem(qt.QSpacerItem(1, 1, qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed))
formLayout.addRow(label, autoScaleCombo)
+
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._visibleAreaButton)
+ layout.addWidget(self._selectedAreaButton)
+ self._scaleToAreaGroup = qt.QGroupBox('Scale to:', self)
+ self._scaleToAreaGroup.setLayout(layout)
+ self._scaleToAreaGroup.setVisible(False)
+ formLayout.addRow(self._scaleToAreaGroup)
+
formLayout.addRow(self._buttonsModal)
formLayout.addRow(self._buttonsNonModal)
formLayout.setSizeConstraint(qt.QLayout.SetMinimumSize)
@@ -965,7 +997,9 @@ class ColormapDialog(qt.QDialog):
self.setTabOrder(self._minValue, self._maxValue)
self.setTabOrder(self._maxValue, self._autoButtons)
self.setTabOrder(self._autoButtons, self._autoScaleCombo)
- self.setTabOrder(self._autoScaleCombo, self._buttonsModal)
+ self.setTabOrder(self._autoScaleCombo, self._visibleAreaButton)
+ self.setTabOrder(self._visibleAreaButton, self._selectedAreaButton)
+ self.setTabOrder(self._selectedAreaButton, self._buttonsModal)
self.setTabOrder(self._buttonsModal, self._buttonsNonModal)
self.setFixedSize(self.sizeHint())
@@ -1094,9 +1128,9 @@ class ColormapDialog(qt.QDialog):
if data.ndim == 3: # RGB(A) images
_logger.info('Converting current image from RGB(A) to grayscale\
in order to compute the intensity distribution')
- data = (data[:, :, 0] * 0.299 +
- data[:, :, 1] * 0.587 +
- data[:, :, 2] * 0.114)
+ data = (data[:,:, 0] * 0.299 +
+ data[:,:, 1] * 0.587 +
+ data[:,:, 2] * 0.114)
# bad hack: get 256 continuous bins in the case we have a B&W
normalizeData = True
@@ -1141,7 +1175,7 @@ class ColormapDialog(qt.QDialog):
bins = histogram.edges[0]
if normalizeData:
if scale == Colormap.LOGARITHM:
- bins = 10**bins
+ bins = 10 ** bins
return histogram.histo, bins
def _getItem(self):
@@ -1173,6 +1207,7 @@ class ColormapDialog(qt.QDialog):
raise ValueError("Item %s is not supported" % item)
self._item = weakref.ref(item, self._itemAboutToFinalize)
finally:
+ self._syncScaleToButtonsEnabled()
self._dataRange = None
self._histogramData = None
self._invalidateData()
@@ -1194,6 +1229,7 @@ class ColormapDialog(qt.QDialog):
return
self._item = None
+ self._syncScaleToButtonsEnabled()
if data is None:
self._data = None
self._itemHolder = None
@@ -1311,6 +1347,55 @@ class ColormapDialog(qt.QDialog):
colormap.setVRange(xmin, xmax)
self._updateWidgetRange()
+ def setColormapRangeFromDataBounds(self, bounds):
+ """Set the range of the colormap from current item and rect.
+
+ If there is no ColormapMixIn item attached to the ColormapDialog,
+ nothing is done.
+
+ :param Union[List[float],None] bounds:
+ (xmin, xmax, ymin, ymax) Rectangular region in data space
+ """
+ if bounds is None:
+ return None # no-op
+
+ colormap = self.getColormap()
+ if colormap is None:
+ return # no-op
+
+ item = self._getItem()
+ if not isinstance(item, items.ColormapMixIn):
+ return None # no-op
+
+ data = item.getColormappedData(copy=False)
+
+ xmin, xmax, ymin, ymax = bounds
+
+ if isinstance(item, items.ImageBase):
+ ox, oy = item.getOrigin()
+ sx, sy = item.getScale()
+
+ ystart = max(0, int((ymin - oy) / sy))
+ ystop = max(0, int(numpy.ceil((ymax - oy) / sy)))
+ xstart = max(0, int((xmin - ox) / sx))
+ xstop = max(0, int(numpy.ceil((xmax - ox) / sx)))
+
+ subset = data[ystart:ystop, xstart:xstop]
+
+ elif isinstance(item, items.Scatter):
+ x = item.getXData(copy=False)
+ y = item.getYData(copy=False)
+ subset = data[
+ numpy.logical_and(
+ numpy.logical_and(xmin <= x, x <= xmax),
+ numpy.logical_and(ymin <= y, y <= ymax))]
+
+ if subset.size == 0:
+ return # no-op
+
+ vmin, vmax = colormap._computeAutoscaleRange(subset)
+ self._setColormapRange(vmin, vmax)
+
def _updateWidgetRange(self):
"""Update the colormap range displayed into the widget."""
xmin, xmax = self._getFiniteColormapRange()
@@ -1387,6 +1472,8 @@ class ColormapDialog(qt.QDialog):
if self._colormapChange.locked():
return
+ self._syncScaleToButtonsEnabled()
+
colormap = self.getColormap()
if colormap is None:
self._comboBoxColormap.setEnabled(False)
@@ -1591,6 +1678,73 @@ class ColormapDialog(qt.QDialog):
vmax = xmax
self._setColormapRange(vmin, vmax)
+ def _syncScaleToButtonsEnabled(self):
+ """Set the state of scale to buttons according to current item and colormap"""
+ colormap = self.getColormap()
+ enabled = self._item is not None and colormap is not None and colormap.isEditable()
+ self._scaleToAreaGroup.setVisible(enabled)
+ self._visibleAreaButton.setEnabled(enabled)
+ if not enabled:
+ self._selectedAreaButton.setChecked(False)
+ self._selectedAreaButton.setEnabled(enabled)
+
+ def _handleScaleToVisibleAreaClicked(self):
+ """Set colormap range from current item's visible area"""
+ item = self._getItem()
+ if item is None:
+ return # no-op
+
+ bounds = item.getVisibleBounds()
+ if bounds is None:
+ return # no-op
+
+ self.setColormapRangeFromDataBounds(bounds)
+
+ def _handleScaleToSelectionToggled(self, checked=False):
+ """Handle toggle of scale to selected are button"""
+ # Reset any previous ROI manager
+ if self._roiForColormapManager is not None:
+ self._roiForColormapManager.clear()
+ self._roiForColormapManager.stop()
+ self._roiForColormapManager = None
+
+ if not checked: # Reset button status
+ self._selectedAreaButton.setWaiting(False)
+ self._selectedAreaButton.setText("Selection")
+ return
+
+ item = self._getItem()
+ if item is None:
+ self._selectedAreaButton.setChecked(False)
+ return # no-op
+
+ plotWidget = item.getPlot()
+ if plotWidget is None:
+ self._selectedAreaButton.setChecked(False)
+ return # no-op
+
+ self._selectedAreaButton.setWaiting(True)
+ self._selectedAreaButton.setText("Draw Area...")
+
+ self._roiForColormapManager = RegionOfInterestManager(parent=plotWidget)
+ cmap = self.getColormap()
+ self._roiForColormapManager.setColor(
+ 'black' if cmap is None else cursorColorForColormap(cmap.getName()))
+ self._roiForColormapManager.sigInteractiveModeFinished.connect(
+ self.__roiInteractiveModeFinished)
+ self._roiForColormapManager.sigInteractiveRoiFinalized.connect(self.__roiFinalized)
+ self._roiForColormapManager.start(RectangleROI)
+
+ def __roiInteractiveModeFinished(self):
+ self._selectedAreaButton.setChecked(False)
+
+ def __roiFinalized(self, roi):
+ self._selectedAreaButton.setChecked(False)
+ if roi is not None:
+ ox, oy = roi.getOrigin()
+ width, height = roi.getSize()
+ self.setColormapRangeFromDataBounds((ox, ox+width, oy, oy+height))
+
def keyPressEvent(self, event):
"""Override key handling.
diff --git a/silx/gui/plot/ColorBar.py b/silx/gui/plot/ColorBar.py
index eff7689..d869825 100644
--- a/silx/gui/plot/ColorBar.py
+++ b/silx/gui/plot/ColorBar.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -36,6 +36,7 @@ import numpy
from ._utils import ticklayout
from .. import qt
+from ..qt import inspect as qt_inspect
from silx.gui import colors
_logger = logging.getLogger(__name__)
@@ -112,14 +113,15 @@ class ColorBarWidget(qt.QWidget):
def _disconnectPlot(self):
"""Disconnect from Plot signals"""
- plot = self.getPlot()
- if plot is not None and self._isConnected:
+ if self._isConnected:
self._isConnected = False
- plot.sigActiveImageChanged.disconnect(
- self._activeImageChanged)
- plot.sigActiveScatterChanged.disconnect(
- self._activeScatterChanged)
- plot.sigPlotSignal.disconnect(self._defaultColormapChanged)
+ plot = self.getPlot()
+ if plot is not None and qt_inspect.isValid(plot):
+ plot.sigActiveImageChanged.disconnect(
+ self._activeImageChanged)
+ plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChanged)
+ plot.sigPlotSignal.disconnect(self._defaultColormapChanged)
def _connectPlot(self):
"""Connect to Plot signals"""
diff --git a/silx/gui/plot/Colormap.py b/silx/gui/plot/Colormap.py
index e797d89..22fea7f 100644
--- a/silx/gui/plot/Colormap.py
+++ b/silx/gui/plot/Colormap.py
@@ -25,11 +25,9 @@
"""Deprecated module providing the Colormap object
"""
-from __future__ import absolute_import
-
__authors__ = ["T. Vincent", "H.Payno"]
__license__ = "MIT"
-__date__ = "24/04/2018"
+__date__ = "27/11/2020"
import silx.utils.deprecation
diff --git a/silx/gui/plot/ImageStack.py b/silx/gui/plot/ImageStack.py
index 3b652ca..fe4b451 100644
--- a/silx/gui/plot/ImageStack.py
+++ b/silx/gui/plot/ImageStack.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2020 European Synchrotron Radiation Facility
+# Copyright (c) 2020-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -41,7 +41,7 @@ import threading
import typing
import logging
-_logger = logging.getLogger(__file__)
+_logger = logging.getLogger(__name__)
class _PlotWithWaitingLabel(qt.QWidget):
@@ -71,6 +71,7 @@ class _PlotWithWaitingLabel(qt.QWidget):
def __init__(self, parent):
super(_PlotWithWaitingLabel, self).__init__(parent=parent)
+ self._autoResetZoom = True
layout = qt.QStackedLayout(self)
layout.setStackingMode(qt.QStackedLayout.StackAll)
@@ -88,6 +89,24 @@ class _PlotWithWaitingLabel(qt.QWidget):
super(_PlotWithWaitingLabel, self).close()
self.updateThread.stop()
+ def setAutoResetZoom(self, reset):
+ """
+ Should we reset the zoom when adding an image (eq. when browsing)
+
+ :param bool reset:
+ """
+ self._autoResetZoom = reset
+ if self._autoResetZoom:
+ self._plot.resetZoom()
+
+ def isAutoResetZoom(self):
+ """
+
+ :return: True if a reset is done when the image change
+ :rtype: bool
+ """
+ return self._autoResetZoom
+
def setWaiting(self, activate=True):
if activate is True:
self._plot.clear()
@@ -97,7 +116,7 @@ class _PlotWithWaitingLabel(qt.QWidget):
def setData(self, data):
self.setWaiting(activate=False)
- self._plot.addImage(data=data)
+ self._plot.addImage(data=data, resetzoom=self._autoResetZoom)
def clear(self):
self._plot.clear()
@@ -160,8 +179,7 @@ class UrlList(qt.QWidget):
sel_items = self._listWidget.findItems(url.path(), qt.Qt.MatchExactly)
if sel_items is None:
_logger.warning(url.path(), ' is not registered in the list.')
- else:
- assert len(sel_items) == 1
+ elif len(sel_items) > 0:
item = sel_items[0]
self._listWidget.setCurrentItem(item)
self.sigCurrentUrlChanged.emit(item.text())
@@ -601,3 +619,18 @@ class ImageStack(qt.QMainWindow):
"""display a simple image of loading..."""
self._plot.setWaiting(activate=True)
+ def setAutoResetZoom(self, reset):
+ """
+ Should we reset the zoom when adding an image (eq. when browsing)
+
+ :param bool reset:
+ """
+ self._plot.setAutoResetZoom(reset)
+
+ def isAutoResetZoom(self) -> bool:
+ """
+
+ :return: True if a reset is done when the image change
+ :rtype: bool
+ """
+ return self._plot.isAutoResetZoom()
diff --git a/silx/gui/plot/ImageView.py b/silx/gui/plot/ImageView.py
index 8cc0cc6..1befe58 100644
--- a/silx/gui/plot/ImageView.py
+++ b/silx/gui/plot/ImageView.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -47,9 +47,13 @@ __date__ = "26/04/2018"
import logging
import numpy
+import collections
+from typing import Union
+import weakref
import silx
from .. import qt
+from .. import colors
from . import items, PlotWindow, PlotWidget, actions
from ..colors import Colormap
@@ -57,192 +61,262 @@ from ..colors import cursorColorForColormap
from .tools import LimitsToolBar
from .Profile import ProfileToolBar
from ...utils.proxy import docstring
+from ...utils.enum import Enum
+from .tools.RadarView import RadarView
+from .utils.axis import SyncAxes
+from ..utils import blockSignals
+from . import _utils
+from .tools.profile import manager
+from .tools.profile import rois
_logger = logging.getLogger(__name__)
-# RadarView ###################################################################
+ProfileSumResult = collections.namedtuple("ProfileResult",
+ ["dataXRange", "dataYRange",
+ 'histoH', 'histoHRange',
+ 'histoV', 'histoVRange',
+ "xCoords", "xData",
+ "yCoords", "yData"])
-class RadarView(qt.QGraphicsView):
- """Widget presenting a synthetic view of a 2D area and
- the current visible area.
- Coordinates are as in QGraphicsView:
- x goes from left to right and y goes from top to bottom.
- This widget preserves the aspect ratio of the areas.
+def computeProfileSumOnRange(imageItem, xRange, yRange, cache=None):
+ """
+ Compute a full vertical and horizontal profile on an image item using a
+ a range in the plot referential.
- The 2D area and the visible area can be set with :meth:`setDataRect`
- and :meth:`setVisibleRect`.
- When the visible area has been dragged by the user, its new position
- is signaled by the *visibleRectDragged* signal.
+ Optionally takes a previous computed result to be able to skip the
+ computation.
- It is possible to invert the direction of the axes by using the
- :meth:`scale` method of QGraphicsView.
+ :rtype: ProfileSumResult
"""
+ data = imageItem.getValueData(copy=False)
+ origin = imageItem.getOrigin()
+ scale = imageItem.getScale()
+ height, width = data.shape
+
+ xMin, xMax = xRange
+ yMin, yMax = yRange
+
+ # Convert plot area limits to image coordinates
+ # and work in image coordinates (i.e., in pixels)
+ xMin = int((xMin - origin[0]) / scale[0])
+ xMax = int((xMax - origin[0]) / scale[0])
+ yMin = int((yMin - origin[1]) / scale[1])
+ yMax = int((yMax - origin[1]) / scale[1])
+
+ if (xMin >= width or xMax < 0 or
+ yMin >= height or yMax < 0):
+ return None
+
+ # The image is at least partly in the plot area
+ # Get the visible bounds in image coords (i.e., in pixels)
+ subsetXMin = 0 if xMin < 0 else xMin
+ subsetXMax = (width if xMax >= width else xMax) + 1
+ subsetYMin = 0 if yMin < 0 else yMin
+ subsetYMax = (height if yMax >= height else yMax) + 1
+
+ if cache is not None:
+ if ((subsetXMin, subsetXMax) == cache.dataXRange and
+ (subsetYMin, subsetYMax) == cache.dataYRange):
+ # The visible area of data is the same
+ return cache
+
+ # Rebuild histograms for visible area
+ visibleData = data[subsetYMin:subsetYMax,
+ subsetXMin:subsetXMax]
+ histoHVisibleData = numpy.nansum(visibleData, axis=0)
+ histoVVisibleData = numpy.nansum(visibleData, axis=1)
+ histoHMin = numpy.nanmin(histoHVisibleData)
+ histoHMax = numpy.nanmax(histoHVisibleData)
+ histoVMin = numpy.nanmin(histoVVisibleData)
+ histoVMax = numpy.nanmax(histoVVisibleData)
+
+ # Convert to histogram curve and update plots
+ # Taking into account origin and scale
+ coords = numpy.arange(2 * histoHVisibleData.size)
+ xCoords = (coords + 1) // 2 + subsetXMin
+ xCoords = origin[0] + scale[0] * xCoords
+ xData = numpy.take(histoHVisibleData, coords // 2)
+ coords = numpy.arange(2 * histoVVisibleData.size)
+ yCoords = (coords + 1) // 2 + subsetYMin
+ yCoords = origin[1] + scale[1] * yCoords
+ yData = numpy.take(histoVVisibleData, coords // 2)
+
+ result = ProfileSumResult(
+ dataXRange=(subsetXMin, subsetXMax),
+ dataYRange=(subsetYMin, subsetYMax),
+ histoH=histoHVisibleData,
+ histoHRange=(histoHMin, histoHMax),
+ histoV=histoVVisibleData,
+ histoVRange=(histoVMin, histoVMax),
+ xCoords=xCoords,
+ xData=xData,
+ yCoords=yCoords,
+ yData=yData)
+
+ return result
+
+
+class _SideHistogram(PlotWidget):
+ """
+ Widget displaying one of the side profile of the ImageView.
- visibleRectDragged = qt.Signal(float, float, float, float)
- """Signals that the visible rectangle has been dragged.
-
- It provides: left, top, width, height in data coordinates.
+ Implement ProfileWindow
"""
- _DATA_PEN = qt.QPen(qt.QColor('white'))
- _DATA_BRUSH = qt.QBrush(qt.QColor('light gray'))
- _VISIBLE_PEN = qt.QPen(qt.QColor('red'))
- _VISIBLE_PEN.setWidth(2)
- _VISIBLE_PEN.setCosmetic(True)
- _VISIBLE_BRUSH = qt.QBrush(qt.QColor(0, 0, 0, 0))
- _TOOLTIP = 'Radar View:\nRed contour: Visible area\nGray area: The image'
-
- _PIXMAP_SIZE = 256
-
- class _DraggableRectItem(qt.QGraphicsRectItem):
- """RectItem which signals its change through visibleRectDragged."""
- def __init__(self, *args, **kwargs):
- super(RadarView._DraggableRectItem, self).__init__(
- *args, **kwargs)
-
- self._previousCursor = None
- self.setFlag(qt.QGraphicsItem.ItemIsMovable)
- self.setFlag(qt.QGraphicsItem.ItemSendsGeometryChanges)
- self.setAcceptHoverEvents(True)
- self._ignoreChange = False
- self._constraint = 0, 0, 0, 0
-
- def setConstraintRect(self, left, top, width, height):
- """Set the constraint rectangle for dragging.
-
- The coordinates are in the _DraggableRectItem coordinate system.
-
- This constraint only applies to modification through interaction
- (i.e., this constraint is not applied to change through API).
-
- If the _DraggableRectItem is smaller than the constraint rectangle,
- the _DraggableRectItem remains within the constraint rectangle.
- If the _DraggableRectItem is wider than the constraint rectangle,
- the constraint rectangle remains within the _DraggableRectItem.
- """
- self._constraint = left, left + width, top, top + height
-
- def setPos(self, *args, **kwargs):
- """Overridden to ignore changes from API in itemChange."""
- self._ignoreChange = True
- super(RadarView._DraggableRectItem, self).setPos(*args, **kwargs)
- self._ignoreChange = False
-
- def moveBy(self, *args, **kwargs):
- """Overridden to ignore changes from API in itemChange."""
- self._ignoreChange = True
- super(RadarView._DraggableRectItem, self).moveBy(*args, **kwargs)
- self._ignoreChange = False
-
- def itemChange(self, change, value):
- """Callback called before applying changes to the item."""
- if (change == qt.QGraphicsItem.ItemPositionChange and
- not self._ignoreChange):
- # Makes sure that the visible area is in the data
- # or that data is in the visible area if area is too wide
- x, y = value.x(), value.y()
- xMin, xMax, yMin, yMax = self._constraint
-
- if self.rect().width() <= (xMax - xMin):
- if x < xMin:
- value.setX(xMin)
- elif x > xMax - self.rect().width():
- value.setX(xMax - self.rect().width())
- else:
- if x > xMin:
- value.setX(xMin)
- elif x < xMax - self.rect().width():
- value.setX(xMax - self.rect().width())
-
- if self.rect().height() <= (yMax - yMin):
- if y < yMin:
- value.setY(yMin)
- elif y > yMax - self.rect().height():
- value.setY(yMax - self.rect().height())
- else:
- if y > yMin:
- value.setY(yMin)
- elif y < yMax - self.rect().height():
- value.setY(yMax - self.rect().height())
-
- if self.pos() != value:
- # Notify change through signal
- views = self.scene().views()
- assert len(views) == 1
- views[0].visibleRectDragged.emit(
- value.x() + self.rect().left(),
- value.y() + self.rect().top(),
- self.rect().width(),
- self.rect().height())
-
- return value
-
- return super(RadarView._DraggableRectItem, self).itemChange(
- change, value)
-
- def hoverEnterEvent(self, event):
- """Called when the mouse enters the rectangle area"""
- self._previousCursor = self.cursor()
- self.setCursor(qt.Qt.OpenHandCursor)
-
- def hoverLeaveEvent(self, event):
- """Called when the mouse leaves the rectangle area"""
- if self._previousCursor is not None:
- self.setCursor(self._previousCursor)
- self._previousCursor = None
-
- def __init__(self, parent=None):
- self._scene = qt.QGraphicsScene()
- self._dataRect = self._scene.addRect(0, 0, 1, 1,
- self._DATA_PEN,
- self._DATA_BRUSH)
- self._visibleRect = self._DraggableRectItem(0, 0, 1, 1)
- self._visibleRect.setPen(self._VISIBLE_PEN)
- self._visibleRect.setBrush(self._VISIBLE_BRUSH)
- self._scene.addItem(self._visibleRect)
-
- super(RadarView, self).__init__(self._scene, parent)
- self.setHorizontalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff)
- self.setVerticalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff)
- self.setFocusPolicy(qt.Qt.NoFocus)
- self.setStyleSheet('border: 0px')
- self.setToolTip(self._TOOLTIP)
-
- def sizeHint(self):
- # """Overridden to avoid sizeHint to depend on content size."""
- return self.minimumSizeHint()
-
- def wheelEvent(self, event):
- # """Overridden to disable vertical scrolling with wheel."""
- event.ignore()
-
- def resizeEvent(self, event):
- # """Overridden to fit current content to new size."""
- self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
- super(RadarView, self).resizeEvent(event)
-
- def setDataRect(self, left, top, width, height):
- """Set the bounds of the data rectangular area.
-
- This sets the coordinate system.
+ sigClose = qt.Signal()
+
+ sigMouseMoved = qt.Signal(float, float)
+
+ def __init__(self, parent=None, backend=None, direction=qt.Qt.Horizontal):
+ super(_SideHistogram, self).__init__(parent=parent, backend=backend)
+ self._direction = direction
+ self.sigPlotSignal.connect(self._plotEvents)
+ self._color = "blue"
+ self.__profile = None
+ self.__profileSum = None
+
+ def _plotEvents(self, eventDict):
+ """Callback for horizontal histogram plot events."""
+ if eventDict['event'] == 'mouseMoved':
+ self.sigMouseMoved.emit(eventDict['x'], eventDict['y'])
+
+ def setProfileColor(self, color):
+ self._color = color
+
+ def setProfileSum(self, result):
+ self.__profileSum = result
+ if self.__profile is None:
+ self.__drawProfileSum()
+
+ def prepareWidget(self, roi):
+ """Implements `ProfileWindow`"""
+ pass
+
+ def setRoiProfile(self, roi):
+ """Implements `ProfileWindow`"""
+ if roi is None:
+ return
+ self._roiColor = colors.rgba(roi.getColor())
+
+ def getProfile(self):
+ """Implements `ProfileWindow`"""
+ return self.__profile
+
+ def setProfile(self, data):
+ """Implements `ProfileWindow`"""
+ self.__profile = data
+ if data is None:
+ self.__drawProfileSum()
+ else:
+ self.__drawProfile()
+
+ def __drawProfileSum(self):
+ """Only draw the profile sum on the plot.
+
+ Other elements are removed
"""
- self._dataRect.setRect(left, top, width, height)
- self._visibleRect.setConstraintRect(left, top, width, height)
- self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+ profileSum = self.__profileSum
+
+ try:
+ self.removeCurve('profile')
+ except Exception:
+ pass
+
+ if profileSum is None:
+ try:
+ self.removeCurve('profilesum')
+ except Exception:
+ pass
+ return
+
+ if self._direction == qt.Qt.Horizontal:
+ xx, yy = profileSum.xCoords, profileSum.xData
+ elif self._direction == qt.Qt.Vertical:
+ xx, yy = profileSum.yData, profileSum.yCoords
+ else:
+ assert False
+
+ self.addCurve(xx, yy,
+ xlabel='', ylabel='',
+ legend="profilesum",
+ color=self._color,
+ linestyle='-',
+ selectable=False,
+ resetzoom=False)
+
+ self.__updateLimits()
- def setVisibleRect(self, left, top, width, height):
- """Set the visible rectangular area.
+ def __drawProfile(self):
+ """Only draw the profile on the plot.
- The coordinates are relative to the data rect.
+ Other elements are removed
"""
- self._visibleRect.setRect(0, 0, width, height)
- self._visibleRect.setPos(left, top)
- self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+ profile = self.__profile
+
+ try:
+ self.removeCurve('profilesum')
+ except Exception:
+ pass
+
+ if profile is None:
+ try:
+ self.removeCurve('profile')
+ except Exception:
+ pass
+ self.setProfileSum(self.__profileSum)
+ return
+
+ if self._direction == qt.Qt.Horizontal:
+ xx, yy = profile.coords, profile.profile
+ elif self._direction == qt.Qt.Vertical:
+ xx, yy = profile.profile, profile.coords
+ else:
+ assert False
+
+ self.addCurve(xx,
+ yy,
+ legend="profile",
+ color=self._roiColor,
+ resetzoom=False)
+
+ self.__updateLimits()
+
+ def __updateLimits(self):
+ if self.__profile:
+ data = self.__profile.profile
+ vMin = numpy.nanmin(data)
+ vMax = numpy.nanmax(data)
+ elif self.__profileSum is not None:
+ if self._direction == qt.Qt.Horizontal:
+ vMin, vMax = self.__profileSum.histoHRange
+ elif self._direction == qt.Qt.Vertical:
+ vMin, vMax = self.__profileSum.histoVRange
+ else:
+ assert False
+ else:
+ vMin, vMax = 0, 0
+
+ # Tune the result using the data margins
+ margins = self.getDataMargins()
+ if self._direction == qt.Qt.Horizontal:
+ _, _, vMin, vMax = _utils.addMarginsToLimits(margins, False, False, 0, 0, vMin, vMax)
+ elif self._direction == qt.Qt.Vertical:
+ vMin, vMax, _, _ = _utils.addMarginsToLimits(margins, False, False, vMin, vMax, 0, 0)
+ else:
+ assert False
+
+ if self._direction == qt.Qt.Horizontal:
+ dataAxis = self.getYAxis()
+ elif self._direction == qt.Qt.Vertical:
+ dataAxis = self.getXAxis()
+ else:
+ assert False
+ with blockSignals(dataAxis):
+ dataAxis.setLimits(vMin, vMax)
-# ImageView ###################################################################
class ImageView(PlotWindow):
"""Display a single image with horizontal and vertical histograms.
@@ -281,10 +355,20 @@ class ImageView(PlotWindow):
Row and columns are either Nan or integer values.
"""
+ class ProfileWindowBehavior(Enum):
+ """ImageView's profile window behavior options"""
+
+ POPUP = 'popup'
+ """All profiles are displayed in pop-up windows"""
+
+ EMBEDDED = 'embedded'
+ """Horizontal, vertical and cross profiles are displayed in
+ sides widgets, others are displayed in pop-up windows.
+ """
+
def __init__(self, parent=None, backend=None):
self._imageLegend = '__ImageView__image' + str(id(self))
self._cache = None # Store currently visible data information
- self._updatingLimits = False
super(ImageView, self).__init__(parent=parent, backend=backend,
resetzoom=True, autoScale=False,
@@ -294,6 +378,11 @@ class ImageView(PlotWindow):
copy=True, save=True, print_=True,
control=False, position=False,
roi=False, mask=True)
+
+ # Enable mask synchronisation to use it in profiles
+ maskToolsWidget = self.getMaskToolsDockWidget().widget()
+ maskToolsWidget.setItemMaskUpdated(True)
+
if parent is None:
self.setWindowTitle('ImageView')
@@ -302,44 +391,40 @@ class ImageView(PlotWindow):
self._initWidgets(backend)
- self.profile = ProfileToolBar(plot=self)
- """"Profile tools attached to this plot.
-
- See :class:`silx.gui.plot.PlotTools.ProfileToolBar`
- """
-
- self.addToolBar(self.profile)
-
- # Sync PlotBackend and ImageView
- self._updateYAxisInverted()
+ self.__profileWindowBehavior = self.ProfileWindowBehavior.POPUP
+ self.__profile = ProfileToolBar(plot=self)
+ self.addToolBar(self.__profile)
def _initWidgets(self, backend):
"""Set-up layout and plots."""
- self._histoHPlot = PlotWidget(backend=backend, parent=self)
- self._histoHPlot.getWidgetHandle().setMinimumHeight(
- self.HISTOGRAMS_HEIGHT)
- self._histoHPlot.getWidgetHandle().setMaximumHeight(
- self.HISTOGRAMS_HEIGHT)
+ self._histoHPlot = _SideHistogram(backend=backend, parent=self, direction=qt.Qt.Horizontal)
+ widgetHandle = self._histoHPlot.getWidgetHandle()
+ widgetHandle.setMinimumHeight(self.HISTOGRAMS_HEIGHT)
+ widgetHandle.setMaximumHeight(self.HISTOGRAMS_HEIGHT)
self._histoHPlot.setInteractiveMode('zoom')
- self._histoHPlot.sigPlotSignal.connect(self._histoHPlotCB)
+ self._histoHPlot.setDataMargins(0., 0., 0.1, 0.1)
+ self._histoHPlot.sigMouseMoved.connect(self._mouseMovedOnHistoH)
+ self._histoHPlot.setProfileColor(self.HISTOGRAMS_COLOR)
+
+ self._histoVPlot = _SideHistogram(backend=backend, parent=self, direction=qt.Qt.Vertical)
+ widgetHandle = self._histoVPlot.getWidgetHandle()
+ widgetHandle.setMinimumWidth(self.HISTOGRAMS_HEIGHT)
+ widgetHandle.setMaximumWidth(self.HISTOGRAMS_HEIGHT)
+ self._histoVPlot.setInteractiveMode('zoom')
+ self._histoVPlot.setDataMargins(0.1, 0.1, 0., 0.)
+ self._histoVPlot.sigMouseMoved.connect(self._mouseMovedOnHistoV)
+ self._histoVPlot.setProfileColor(self.HISTOGRAMS_COLOR)
self.setPanWithArrowKeys(True)
-
self.setInteractiveMode('zoom') # Color set in setColormap
self.sigPlotSignal.connect(self._imagePlotCB)
- self.getYAxis().sigInvertedChanged.connect(self._updateYAxisInverted)
self.sigActiveImageChanged.connect(self._activeImageChangedSlot)
- self._histoVPlot = PlotWidget(backend=backend, parent=self)
- self._histoVPlot.getWidgetHandle().setMinimumWidth(
- self.HISTOGRAMS_HEIGHT)
- self._histoVPlot.getWidgetHandle().setMaximumWidth(
- self.HISTOGRAMS_HEIGHT)
- self._histoVPlot.setInteractiveMode('zoom')
- self._histoVPlot.sigPlotSignal.connect(self._histoVPlotCB)
-
self._radarView = RadarView(parent=self)
- self._radarView.visibleRectDragged.connect(self._radarViewCB)
+ self._radarView.setPlotWidget(self)
+
+ self.__syncXAxis = SyncAxes([self.getXAxis(), self._histoHPlot.getXAxis()])
+ self.__syncYAxis = SyncAxes([self.getYAxis(), self._histoVPlot.getYAxis()])
self.__setCentralWidget()
@@ -382,113 +467,12 @@ class ImageView(PlotWindow):
"""Update histograms content using current active image."""
activeImage = self.getActiveImage()
if activeImage is not None:
- wasUpdatingLimits = self._updatingLimits
- self._updatingLimits = True
-
- data = activeImage.getData(copy=False)
- origin = activeImage.getOrigin()
- scale = activeImage.getScale()
- height, width = data.shape
-
- xMin, xMax = self.getXAxis().getLimits()
- yMin, yMax = self.getYAxis().getLimits()
-
- # Convert plot area limits to image coordinates
- # and work in image coordinates (i.e., in pixels)
- xMin = int((xMin - origin[0]) / scale[0])
- xMax = int((xMax - origin[0]) / scale[0])
- yMin = int((yMin - origin[1]) / scale[1])
- yMax = int((yMax - origin[1]) / scale[1])
-
- if (xMin < width and xMax >= 0 and
- yMin < height and yMax >= 0):
- # The image is at least partly in the plot area
- # Get the visible bounds in image coords (i.e., in pixels)
- subsetXMin = 0 if xMin < 0 else xMin
- subsetXMax = (width if xMax >= width else xMax) + 1
- subsetYMin = 0 if yMin < 0 else yMin
- subsetYMax = (height if yMax >= height else yMax) + 1
-
- if (self._cache is None or
- subsetXMin != self._cache['dataXMin'] or
- subsetXMax != self._cache['dataXMax'] or
- subsetYMin != self._cache['dataYMin'] or
- subsetYMax != self._cache['dataYMax']):
- # The visible area of data has changed, update histograms
-
- # Rebuild histograms for visible area
- visibleData = data[subsetYMin:subsetYMax,
- subsetXMin:subsetXMax]
- histoHVisibleData = numpy.sum(visibleData, axis=0)
- histoVVisibleData = numpy.sum(visibleData, axis=1)
-
- self._cache = {
- 'dataXMin': subsetXMin,
- 'dataXMax': subsetXMax,
- 'dataYMin': subsetYMin,
- 'dataYMax': subsetYMax,
-
- 'histoH': histoHVisibleData,
- 'histoHMin': numpy.min(histoHVisibleData),
- 'histoHMax': numpy.max(histoHVisibleData),
-
- 'histoV': histoVVisibleData,
- 'histoVMin': numpy.min(histoVVisibleData),
- 'histoVMax': numpy.max(histoVVisibleData)
- }
-
- # Convert to histogram curve and update plots
- # Taking into account origin and scale
- coords = numpy.arange(2 * histoHVisibleData.size)
- xCoords = (coords + 1) // 2 + subsetXMin
- xCoords = origin[0] + scale[0] * xCoords
- xData = numpy.take(histoHVisibleData, coords // 2)
- self._histoHPlot.addCurve(xCoords, xData,
- xlabel='', ylabel='',
- replace=False,
- color=self.HISTOGRAMS_COLOR,
- linestyle='-',
- selectable=False)
- vMin = self._cache['histoHMin']
- vMax = self._cache['histoHMax']
- vOffset = 0.1 * (vMax - vMin)
- if vOffset == 0.:
- vOffset = 1.
- self._histoHPlot.getYAxis().setLimits(vMin - vOffset,
- vMax + vOffset)
-
- coords = numpy.arange(2 * histoVVisibleData.size)
- yCoords = (coords + 1) // 2 + subsetYMin
- yCoords = origin[1] + scale[1] * yCoords
- yData = numpy.take(histoVVisibleData, coords // 2)
- self._histoVPlot.addCurve(yData, yCoords,
- xlabel='', ylabel='',
- replace=False,
- color=self.HISTOGRAMS_COLOR,
- linestyle='-',
- selectable=False)
- vMin = self._cache['histoVMin']
- vMax = self._cache['histoVMax']
- vOffset = 0.1 * (vMax - vMin)
- if vOffset == 0.:
- vOffset = 1.
- self._histoVPlot.getXAxis().setLimits(vMin - vOffset,
- vMax + vOffset)
- else:
- self._dirtyCache()
- self._histoHPlot.remove(kind='curve')
- self._histoVPlot.remove(kind='curve')
-
- self._updatingLimits = wasUpdatingLimits
-
- def _updateRadarView(self):
- """Update radar view visible area.
-
- Takes care of y coordinate conversion.
- """
- xMin, xMax = self.getXAxis().getLimits()
- yMin, yMax = self.getYAxis().getLimits()
- self._radarView.setVisibleRect(xMin, yMin, xMax - xMin, yMax - yMin)
+ xRange = self.getXAxis().getLimits()
+ yRange = self.getYAxis().getLimits()
+ result = computeProfileSumOnRange(activeImage, xRange, yRange, self._cache)
+ self._cache = result
+ self._histoHPlot.setProfileSum(result)
+ self._histoVPlot.setProfileSum(result)
# Plots event listeners
@@ -513,104 +497,49 @@ class ImageView(PlotWindow):
data[y][x])
elif eventDict['event'] == 'limitsChanged':
- self._updateHistogramsLimits()
-
- def _updateHistogramsLimits(self):
- # Do not handle histograms limitsChanged while
- # updating their limits from here.
- self._updatingLimits = True
-
- # Refresh histograms
self._updateHistograms()
- xMin, xMax = self.getXAxis().getLimits()
- yMin, yMax = self.getYAxis().getLimits()
+ def _mouseMovedOnHistoH(self, x, y):
+ if self._cache is None:
+ return
+ activeImage = self.getActiveImage()
+ if activeImage is None:
+ return
- # Set horizontal histo limits
- self._histoHPlot.getXAxis().setLimits(xMin, xMax)
+ xOrigin = activeImage.getOrigin()[0]
+ xScale = activeImage.getScale()[0]
- # Set vertical histo limits
- self._histoVPlot.getYAxis().setLimits(yMin, yMax)
+ minValue = xOrigin + xScale * self._cache.dataXRange[0]
- self._updateRadarView()
+ if x >= minValue:
+ data = self._cache.histoH
+ column = int((x - minValue) / xScale)
+ if column >= 0 and column < data.shape[0]:
+ self.valueChanged.emit(
+ float('nan'),
+ float(column + self._cache.dataXRange[0]),
+ data[column])
- self._updatingLimits = False
+ def _mouseMovedOnHistoV(self, x, y):
+ if self._cache is None:
+ return
+ activeImage = self.getActiveImage()
+ if activeImage is None:
+ return
- def _histoHPlotCB(self, eventDict):
- """Callback for horizontal histogram plot events."""
- if eventDict['event'] == 'mouseMoved':
- if self._cache is not None:
- activeImage = self.getActiveImage()
- if activeImage is not None:
- xOrigin = activeImage.getOrigin()[0]
- xScale = activeImage.getScale()[0]
-
- minValue = xOrigin + xScale * self._cache['dataXMin']
-
- if eventDict['x'] >= minValue:
- data = self._cache['histoH']
- column = int((eventDict['x'] - minValue) / xScale)
- if column >= 0 and column < data.shape[0]:
- self.valueChanged.emit(
- float('nan'),
- float(column + self._cache['dataXMin']),
- data[column])
+ yOrigin = activeImage.getOrigin()[1]
+ yScale = activeImage.getScale()[1]
- elif eventDict['event'] == 'limitsChanged':
- if (not self._updatingLimits and
- eventDict['xdata'] != self.getXAxis().getLimits()):
- xMin, xMax = eventDict['xdata']
- self.getXAxis().setLimits(xMin, xMax)
+ minValue = yOrigin + yScale * self._cache.dataYRange[0]
- def _histoVPlotCB(self, eventDict):
- """Callback for vertical histogram plot events."""
- if eventDict['event'] == 'mouseMoved':
- if self._cache is not None:
- activeImage = self.getActiveImage()
- if activeImage is not None:
- yOrigin = activeImage.getOrigin()[1]
- yScale = activeImage.getScale()[1]
-
- minValue = yOrigin + yScale * self._cache['dataYMin']
-
- if eventDict['y'] >= minValue:
- data = self._cache['histoV']
- row = int((eventDict['y'] - minValue) / yScale)
- if row >= 0 and row < data.shape[0]:
- self.valueChanged.emit(
- float(row + self._cache['dataYMin']),
- float('nan'),
- data[row])
-
- elif eventDict['event'] == 'limitsChanged':
- if (not self._updatingLimits and
- eventDict['ydata'] != self.getYAxis().getLimits()):
- yMin, yMax = eventDict['ydata']
- self.getYAxis().setLimits(yMin, yMax)
-
- def _radarViewCB(self, left, top, width, height):
- """Slot for radar view visible rectangle changes."""
- if not self._updatingLimits:
- # Takes care of Y axis conversion
- self.setLimits(left, left + width, top, top + height)
-
- def _updateYAxisInverted(self, inverted=None):
- """Sync image, vertical histogram and radar view axis orientation."""
- if inverted is None:
- # Do not perform this when called from plot signal
- inverted = self.getYAxis().isInverted()
-
- self._histoVPlot.getYAxis().setInverted(inverted)
-
- # Use scale to invert radarView
- # RadarView default Y direction is from top to bottom
- # As opposed to Plot. So invert RadarView when Plot is NOT inverted.
- self._radarView.resetTransform()
- if not inverted:
- self._radarView.scale(1., -1.)
- self._updateRadarView()
-
- self._radarView.update()
+ if y >= minValue:
+ data = self._cache.histoV
+ row = int((y - minValue) / yScale)
+ if row >= 0 and row < data.shape[0]:
+ self.valueChanged.emit(
+ float(row + self._cache.dataYRange[0]),
+ float('nan'),
+ data[row])
def _activeImageChangedSlot(self, previous, legend):
"""Handle Plot active image change.
@@ -620,6 +549,53 @@ class ImageView(PlotWindow):
self._dirtyCache()
self._updateHistograms()
+ def setProfileWindowBehavior(self, behavior: Union[str, ProfileWindowBehavior]):
+ """Set where profile widgets are displayed.
+
+ :param ProfileWindowBehavior behavior:
+ - 'popup': All profiles are displayed in pop-up windows
+ - 'embedded': Horizontal, vertical and cross profiles are displayed in
+ sides widgets, others are displayed in pop-up windows.
+ """
+ behavior = self.ProfileWindowBehavior.from_value(behavior)
+ if behavior is not self.getProfileWindowBehavior():
+ manager = self.__profile.getProfileManager()
+ manager.clearProfile()
+ manager.requestUpdateAllProfile()
+
+ if behavior is self.ProfileWindowBehavior.EMBEDDED:
+ horizontalProfileWindow = self._histoHPlot
+ verticalProfileWindow = self._histoVPlot
+ else:
+ horizontalProfileWindow = None
+ verticalProfileWindow = None
+
+ manager.setSpecializedProfileWindow(
+ rois.ProfileImageHorizontalLineROI, horizontalProfileWindow
+ )
+ manager.setSpecializedProfileWindow(
+ rois.ProfileImageVerticalLineROI, verticalProfileWindow
+ )
+ self.__profileWindowBehavior = behavior
+
+ def getProfileWindowBehavior(self) -> ProfileWindowBehavior:
+ """Returns current profile display behavior.
+
+ See :meth:`setProfileWindowBehavior` and :class:`ProfileWindowBehavior`
+ """
+ return self.__profileWindowBehavior
+
+ def getProfileToolBar(self):
+ """"Returns profile tools attached to this plot.
+
+ :rtype: silx.gui.plot.PlotTools.ProfileToolBar
+ """
+ return self.__profile
+
+ @property
+ def profile(self):
+ return self.getProfileToolBar()
+
def getHistogram(self, axis):
"""Return the histogram and corresponding row or column extent.
@@ -639,12 +615,12 @@ class ImageView(PlotWindow):
else:
if axis == 'x':
return dict(
- data=numpy.array(self._cache['histoH'], copy=True),
- extent=(self._cache['dataXMin'], self._cache['dataXMax']))
+ data=numpy.array(self._cache.histoH, copy=True),
+ extent=self._cache.dataXRange)
else:
return dict(
- data=numpy.array(self._cache['histoV'], copy=True),
- extent=(self._cache['dataYMin'], self._cache['dataYMax']))
+ data=numpy.array(self._cache.histoV, copy=True),
+ extent=(self._cache.dataYRange))
def radarView(self):
"""Get the lower right radarView widget."""
@@ -656,13 +632,10 @@ class ImageView(PlotWindow):
:param RadarView radarView: Widget subclassing RadarView to replace
the lower right corner widget.
"""
- self._radarView.visibleRectDragged.disconnect(self._radarViewCB)
self._radarView = radarView
- self._radarView.visibleRectDragged.connect(self._radarViewCB)
+ self._radarView.setPlotWidget(self)
self.centralWidget().layout().addWidget(self._radarView, 1, 1)
- self._updateYAxisInverted()
-
# High-level API
def getColormap(self):
@@ -782,7 +755,6 @@ class ImageView(PlotWindow):
data = numpy.array(image, order='C', copy=copy)
assert data.size != 0
assert len(data.shape) == 2
- height, width = data.shape
self.addImage(data,
legend=self._imageLegend,
@@ -791,16 +763,8 @@ class ImageView(PlotWindow):
resetzoom=False)
self.setActiveImage(self._imageLegend)
self._updateHistograms()
-
- self._radarView.setDataRect(origin[0],
- origin[1],
- width * scale[0],
- height * scale[1])
-
if reset:
self.resetZoom()
- else:
- self._updateHistogramsLimits()
# ImageViewMainWindow #########################################################
@@ -839,16 +803,22 @@ class ImageViewMainWindow(ImageView):
menu.addAction(actions.control.KeepAspectRatioAction(self, self))
menu.addAction(actions.control.YAxisInvertedAction(self, self))
- menu = self.menuBar().addMenu('Profile')
- menu.addAction(self.profile.hLineAction)
- menu.addAction(self.profile.vLineAction)
- menu.addAction(self.profile.crossAction)
- menu.addAction(self.profile.lineAction)
- menu.addAction(self.profile.clearAction)
+ self.__profileMenu = self.menuBar().addMenu('Profile')
+ self.__updateProfileMenu()
# Connect to ImageView's signal
self.valueChanged.connect(self._statusBarSlot)
+ def __updateProfileMenu(self):
+ """Update actions available in 'Profile' menu"""
+ profile = self.getProfileToolBar()
+ self.__profileMenu.clear()
+ self.__profileMenu.addAction(profile.hLineAction)
+ self.__profileMenu.addAction(profile.vLineAction)
+ self.__profileMenu.addAction(profile.crossAction)
+ self.__profileMenu.addAction(profile.lineAction)
+ self.__profileMenu.addAction(profile.clearAction)
+
def _statusBarSlot(self, row, column, value):
"""Update status bar with coordinates/value from plots."""
if numpy.isnan(row):
@@ -863,11 +833,13 @@ class ImageViewMainWindow(ImageView):
self.statusBar().showMessage(msg)
- def setImage(self, image, *args, **kwargs):
- """Set the displayed image.
+ @docstring(ImageView)
+ def setProfileWindowBehavior(self, behavior: str):
+ super().setProfileWindowBehavior(behavior)
+ self.__updateProfileMenu()
- See :meth:`ImageView.setImage` for details.
- """
+ @docstring(ImageView)
+ def setImage(self, image, *args, **kwargs):
if hasattr(image, 'dtype') and hasattr(image, 'shape'):
assert len(image.shape) == 2
height, width = image.shape
diff --git a/silx/gui/plot/LegendSelector.py b/silx/gui/plot/LegendSelector.py
index 0ea0fc8..94112aa 100755
--- a/silx/gui/plot/LegendSelector.py
+++ b/silx/gui/plot/LegendSelector.py
@@ -524,11 +524,49 @@ class LegendListView(qt.QListView):
self.setContextMenu(contextMenu)
def setLegendList(self, legendList, row=None):
- self.clear()
- if row is None:
- row = 0
- model = self.model()
- model.insertLegendList(row, legendList)
+ if row is not None:
+ model = self.model()
+ model.insertLegendList(row, legendList)
+ elif len(legendList) != self.model().rowCount():
+ self.clear()
+ model = self.model()
+ model.insertLegendList(0, legendList)
+ else:
+ model = self.model()
+ for i, (new_legend, icon) in enumerate(legendList):
+ modelIndex = model.index(i)
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ if new_legend != legend:
+ model.setData(modelIndex, new_legend, qt.Qt.DisplayRole)
+
+ color = modelIndex.data(LegendModel.iconColorRole)
+ new_color = icon.get('color', None)
+ if new_color != color:
+ model.setData(modelIndex, new_color, LegendModel.iconColorRole)
+
+ linewidth = modelIndex.data(LegendModel.iconLineWidthRole)
+ new_linewidth = icon.get('linewidth', 1.0)
+ if new_linewidth != linewidth:
+ model.setData(modelIndex, new_linewidth, LegendModel.iconLineWidthRole)
+
+ linestyle = modelIndex.data(LegendModel.iconLineStyleRole)
+ new_linestyle = icon.get('linestyle', None)
+ visible = not LegendIconWidget.isEmptyLineStyle(new_linestyle)
+ model.setData(modelIndex, visible, LegendModel.showLineRole)
+ if new_linestyle != linestyle:
+ model.setData(modelIndex, new_linestyle, LegendModel.iconLineStyleRole)
+
+ symbol = modelIndex.data(LegendModel.iconSymbolRole)
+ new_symbol = icon.get('symbol', None)
+ visible = not LegendIconWidget.isEmptySymbol(new_symbol)
+ model.setData(modelIndex, visible, LegendModel.showSymbolRole)
+ if new_symbol != symbol:
+ model.setData(modelIndex, new_symbol, LegendModel.iconSymbolRole)
+
+ selected = modelIndex.data(qt.Qt.CheckStateRole)
+ new_selected = icon.get('selected', True)
+ if new_selected != selected:
+ model.setData(modelIndex, new_selected, qt.Qt.CheckStateRole)
_logger.debug('LegendListView.setLegendList(legendList) finished')
def clear(self):
diff --git a/silx/gui/plot/MaskToolsWidget.py b/silx/gui/plot/MaskToolsWidget.py
index 8ff8641..1ec1e7f 100644
--- a/silx/gui/plot/MaskToolsWidget.py
+++ b/silx/gui/plot/MaskToolsWidget.py
@@ -32,11 +32,9 @@ This widget is meant to work with :class:`silx.gui.plot.PlotWidget`.
"""
from __future__ import division
-
__authors__ = ["T. Vincent", "P. Knobel"]
__license__ = "MIT"
-__date__ = "15/02/2019"
-
+__date__ = "08/12/2020"
import os
import sys
@@ -53,16 +51,15 @@ from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDo
from . import items
from ..colors import cursorColorForColormap, rgba
from .. import qt
+from ..utils import LockReentrant
from silx.third_party.EdfFile import EdfFile
from silx.third_party.TiffIO import TiffIO
import fabio
-
_logger = logging.getLogger(__name__)
-
_HDF5_EXT_STR = ' '.join(['*' + ext for ext in NEXUS_HDF5_EXT])
@@ -91,6 +88,7 @@ class ImageMask(BaseMask):
This is meant for internal use by :class:`MaskToolsWidget`.
"""
+
def __init__(self, image=None):
"""
@@ -193,7 +191,7 @@ class ImageMask(BaseMask):
selection = self._mask[max(0, row):row + height + 1,
max(0, col):col + width + 1]
if mask:
- selection[:, :] = level
+ selection[:,:] = level
else:
selection[selection == level] = 0
self._notify()
@@ -289,6 +287,38 @@ class MaskToolsWidget(BaseMaskToolsWidget):
self._z = 1 # Mask layer in plot
self._data = numpy.zeros((0, 0), dtype=numpy.uint8) # Store image
+ self.__itemMaskUpdatedLock = LockReentrant()
+ self.__itemMaskUpdated = False
+
+ def __maskStateChanged(self) -> None:
+ """Handle mask commit to update item mask"""
+ item = self._mask.getDataItem()
+ if item is not None:
+ with self.__itemMaskUpdatedLock:
+ item.setMaskData(self._mask.getMask(copy=True), copy=False)
+
+ def setItemMaskUpdated(self, enabled: bool) -> None:
+ """Toggle item mask and mask tool synchronisation.
+
+ :param bool enabled: True to synchronise. Default: False
+ """
+ enabled = bool(enabled)
+ if enabled != self.__itemMaskUpdated:
+ if self.__itemMaskUpdated:
+ self._mask.sigStateChanged.disconnect(self.__maskStateChanged)
+ self.__itemMaskUpdated = enabled
+ if self.__itemMaskUpdated:
+ # Synchronize item and tool mask
+ self._setMaskedImage(self._mask.getDataItem())
+ self._mask.sigStateChanged.connect(self.__maskStateChanged)
+
+ def isItemMaskUpdated(self) -> bool:
+ """Returns whether or not item and mask tool masks are synchronised.
+
+ :rtype: bool
+ """
+ return self.__itemMaskUpdated
+
def setSelectionMask(self, mask, copy=True):
"""Set the mask to a new array.
@@ -319,13 +349,6 @@ class MaskToolsWidget(BaseMaskToolsWidget):
if numpy.array_equal(mask, self.getSelectionMask()):
return mask.shape
- # ensure all mask attributes are synchronized with the active image
- # and connect listener
- activeImage = self.plot.getActiveImage()
- if activeImage is not None and activeImage.getName() != self._maskName:
- self._activeImageChanged()
- self.plot.sigActiveImageChanged.connect(self._activeImageChanged)
-
if self._data.shape[0:2] == (0, 0) or mask.shape == self._data.shape[0:2]:
self._mask.setMask(mask, copy=copy)
self._mask.commit()
@@ -339,7 +362,7 @@ class MaskToolsWidget(BaseMaskToolsWidget):
dtype=numpy.uint8)
height = min(self._data.shape[0], mask.shape[0])
width = min(self._data.shape[1], mask.shape[1])
- resizedMask[:height, :width] = mask[:height, :width]
+ resizedMask[:height,:width] = mask[:height,:width]
self._mask.setMask(resizedMask, copy=False)
self._mask.commit()
return resizedMask.shape
@@ -374,7 +397,9 @@ class MaskToolsWidget(BaseMaskToolsWidget):
self._activeImageChangedAfterCare)
except (RuntimeError, TypeError):
pass
- self._activeImageChanged() # Init mask + enable/disable widget
+
+ # Sync with current active image
+ self._setMaskedImage(self.plot.getActiveImage())
self.plot.sigActiveImageChanged.connect(self._activeImageChanged)
def hideEvent(self, event):
@@ -383,14 +408,41 @@ class MaskToolsWidget(BaseMaskToolsWidget):
self._activeImageChanged)
except (RuntimeError, TypeError):
pass
+
+ image = self.getMaskedItem()
+ if image is not None:
+ try:
+ image.sigItemChanged.disconnect(self.__imageChanged)
+ except (RuntimeError, TypeError):
+ pass # TODO should not happen
+
if self.isMaskInteractionActivated():
# Disable drawing tool
self.browseAction.trigger()
- if self.getSelectionMask(copy=False) is not None:
+ if self.isItemMaskUpdated(): # No "after-care"
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
+ self._mask.setDataItem(None)
+ self._mask.reset()
+
+ if self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind='image')
+
+ elif self.getSelectionMask(copy=False) is not None:
self.plot.sigActiveImageChanged.connect(
self._activeImageChangedAfterCare)
+ def _activeImageChanged(self, previous, current):
+ """Reacts upon active image change.
+
+ Only handle change of active image items here.
+ """
+ if previous != current:
+ image = self.plot.getActiveImage()
+ if image is not None and image.getName() == self._maskName:
+ image = None # Active image is the mask
+ self._setMaskedImage(image)
+
def _setOverlayColorForImage(self, image):
"""Set the color of overlay adapted to image
@@ -443,41 +495,93 @@ class MaskToolsWidget(BaseMaskToolsWidget):
self._mask.setDataItem(activeImage)
self._updatePlotMask()
- def _activeImageChanged(self, *args):
- """Update widget and mask according to active image changes"""
- activeImage = self.plot.getActiveImage()
- if (activeImage is None or activeImage.getName() == self._maskName or
- activeImage.getData(copy=False).size == 0):
- # No active image or active image is the mask or image has no data...
+ def _setMaskedImage(self, image):
+ """Change the image that is used a reference to author the mask"""
+ previous = self.getMaskedItem()
+ if previous is not None and self.isVisible():
+ # Disconnect from previous image
+ try:
+ previous.sigItemChanged.disconnect(self.__imageChanged)
+ except TypeError:
+ pass # TODO fixme should not happen
+
+ # Set the image
+ self._mask.setDataItem(image)
+
+ if image is None: # No image, disable mask
self.setEnabled(False)
self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
self._mask.reset()
self._mask.commit()
- else: # There is an active image
- self.setEnabled(True)
+ self._updateInteractiveMode()
+
+ else: # Update and connect to image's sigItemChanged
+ if self.isItemMaskUpdated():
+ if image.getMaskData(copy=False) is None:
+ # Image item has no mask: use current mask from the tool
+ image.setMaskData(
+ self.getSelectionMask(copy=False), copy=True)
+ else: # Image item has a mask: set it in tool
+ self.setSelectionMask(
+ image.getMaskData(copy=False), copy=True)
+ self._mask.resetHistory()
+ self.__imageUpdated()
+ if self.isVisible():
+ image.sigItemChanged.connect(self.__imageChanged)
+
+ def __imageChanged(self, event):
+ """Reacts upon image item changes"""
+ image = self._mask.getDataItem()
+ if image is None:
+ _logger.error("Mask is not attached to an image")
+ return
- self._setOverlayColorForImage(activeImage)
+ if event in (items.ItemChangedType.COLORMAP,
+ items.ItemChangedType.DATA,
+ items.ItemChangedType.POSITION,
+ items.ItemChangedType.SCALE,
+ items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.ZVALUE):
+ self.__imageUpdated()
+
+ elif (event == items.ItemChangedType.MASK and
+ self.isItemMaskUpdated() and
+ not self.__itemMaskUpdatedLock.locked()):
+ # Update mask from the image item unless mask tool is updating it
+ self.setSelectionMask(image.getMaskData(copy=False), copy=True)
+
+ def __imageUpdated(self):
+ """Synchronize mask with current state of the image"""
+ image = self._mask.getDataItem()
+ if image is None:
+ _logger.error("No active image while expecting one")
+ return
- self._setMaskColors(self.levelSpinBox.value(),
- self.transparencySlider.value() /
- self.transparencySlider.maximum())
+ self._setOverlayColorForImage(image)
- self._origin = activeImage.getOrigin()
- self._scale = activeImage.getScale()
- self._z = activeImage.getZValue() + 1
- self._data = activeImage.getData(copy=False)
- self._mask.setDataItem(activeImage)
- if self._data.shape[:2] != self._mask.getMask(copy=False).shape:
- self._mask.reset(self._data.shape[:2])
- self._mask.commit()
- else:
- # Refresh in case origin, scale, z changed
- self._updatePlotMask()
+ self._setMaskColors(self.levelSpinBox.value(),
+ self.transparencySlider.value() /
+ self.transparencySlider.maximum())
+
+ self._origin = image.getOrigin()
+ self._scale = image.getScale()
+ self._z = image.getZValue() + 1
+ self._data = image.getData(copy=False)
+ self._mask.setDataItem(image)
+ if self._data.shape[:2] != self._mask.getMask(copy=False).shape:
+ self._mask.reset(self._data.shape[:2])
+ self._mask.commit()
+ else:
+ # Refresh in case origin, scale, z changed
+ self._updatePlotMask()
+
+ # Visible and with data
+ self.setEnabled(image.isVisible() and self._data.size != 0)
- # Threshold tools only available for data with colormap
- self.thresholdGroup.setEnabled(self._data.ndim == 2)
+ # Threshold tools only available for data with colormap
+ self.thresholdGroup.setEnabled(self._data.ndim == 2)
self._updateInteractiveMode()
@@ -809,6 +913,7 @@ class MaskToolsDockWidget(BaseMaskToolsDockWidget):
:param plot: The PlotWidget this widget is operating on
:paran str name: The title of this widget
"""
+
def __init__(self, parent=None, plot=None, name='Mask'):
widget = MaskToolsWidget(plot=plot)
super(MaskToolsDockWidget, self).__init__(parent, name, widget)
diff --git a/silx/gui/plot/PlotWidget.py b/silx/gui/plot/PlotWidget.py
index 23b7fe9..2a211de 100755
--- a/silx/gui/plot/PlotWidget.py
+++ b/silx/gui/plot/PlotWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -42,6 +42,7 @@ from collections import OrderedDict, namedtuple
from contextlib import contextmanager
import datetime as dt
import itertools
+import typing
import warnings
import numpy
@@ -84,6 +85,166 @@ _PlotDataRange = namedtuple('PlotDataRange',
['x', 'y', 'yright'])
+class _PlotWidgetSelection(qt.QObject):
+ """Object managing a :class:`PlotWidget` selection.
+
+ It is a wrapper over :class:`PlotWidget`'s active items API.
+
+ :param PlotWidget parent:
+ """
+
+ sigCurrentItemChanged = qt.Signal(object, object)
+ """This signal is emitted whenever the current item changes.
+
+ It provides the current and previous items.
+ """
+
+ sigSelectedItemsChanged = qt.Signal()
+ """Signal emitted whenever the list of selected items changes."""
+
+ def __init__(self, parent):
+ assert isinstance(parent, PlotWidget)
+ super(_PlotWidgetSelection, self).__init__(parent=parent)
+
+ # Init history
+ self.__history = [ # Store active items from most recent to oldest
+ item for item in (parent.getActiveCurve(),
+ parent.getActiveImage(),
+ parent.getActiveScatter())
+ if item is not None]
+
+ self.__current = self.__mostRecentActiveItem()
+
+ parent.sigActiveImageChanged.connect(self._activeImageChanged)
+ parent.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ parent.sigActiveScatterChanged.connect(self._activeScatterChanged)
+
+ def __mostRecentActiveItem(self) -> typing.Optional[items.Item]:
+ """Returns most recent active item."""
+ return self.__history[0] if len(self.__history) >= 1 else None
+
+ def getSelectedItems(self) -> typing.Tuple[items.Item]:
+ """Returns the list of currently selected items in the :class:`PlotWidget`.
+
+ The list is given from most recently current item to oldest one."""
+ plot = self.parent()
+ if plot is None:
+ return ()
+
+ active = tuple(self.__history)
+
+ current = self.getCurrentItem()
+ if current is not None and current not in active:
+ # Current might not be an active item, if so add it
+ active = (current,) + active
+
+ return active
+
+ def getCurrentItem(self) -> typing.Optional[items.Item]:
+ """Returns the current item in the :class:`PlotWidget` or None. """
+ return self.__current
+
+ def setCurrentItem(self, item: typing.Optional[items.Item]):
+ """Set the current item in the :class:`PlotWidget`.
+
+ :param item:
+ The new item to select or None to clear the selection.
+ :raise ValueError: If the item is not the :class:`PlotWidget`
+ """
+ previous = self.getCurrentItem()
+ if previous is item:
+ return
+
+ previousSelected = self.getSelectedItems()
+
+ if item is None:
+ self.__current = None
+
+ # Reset all PlotWidget active items
+ plot = self.parent()
+ if plot is not None:
+ for kind in PlotWidget._ACTIVE_ITEM_KINDS:
+ if plot._getActiveItem(kind) is not None:
+ plot._setActiveItem(kind, None)
+
+ elif isinstance(item, items.Item):
+ plot = self.parent()
+ if plot is None or item.getPlot() is not plot:
+ raise ValueError(
+ "Item is not in the PlotWidget: %s" % str(item))
+ self.__current = item
+
+ kind = plot._itemKind(item)
+
+ # Clean-up history to be safe
+ self.__history = [item for item in self.__history
+ if PlotWidget._itemKind(item) != kind]
+
+ # Sync active item if needed
+ if (kind in plot._ACTIVE_ITEM_KINDS and
+ item is not plot._getActiveItem(kind)):
+ plot._setActiveItem(kind, item.getName())
+ else:
+ raise ValueError("Not an Item: %s" % str(item))
+
+ self.sigCurrentItemChanged.emit(previous, item)
+
+ if previousSelected != self.getSelectedItems():
+ self.sigSelectedItemsChanged.emit()
+
+ def __activeItemChanged(self,
+ kind: str,
+ previous: typing.Optional[str],
+ legend: typing.Optional[str]):
+ """Set current item from kind and legend"""
+ if previous == legend:
+ return # No-op for update of item
+
+ plot = self.parent()
+ if plot is None:
+ return
+
+ previousSelected = self.getSelectedItems()
+
+ # Remove items of this kind from the history
+ self.__history = [item for item in self.__history
+ if PlotWidget._itemKind(item) != kind]
+
+ # Retrieve current item
+ if legend is None: # Use most recent active item
+ currentItem = self.__mostRecentActiveItem()
+ else:
+ currentItem = plot._getItem(kind=kind, legend=legend)
+ if currentItem is None: # Fallback in case something went wrong
+ currentItem = self.__mostRecentActiveItem()
+
+ # Update history
+ if currentItem is not None:
+ while currentItem in self.__history:
+ self.__history.remove(currentItem)
+ self.__history.insert(0, currentItem)
+
+ if currentItem != self.__current:
+ previousItem = self.__current
+ self.__current = currentItem
+ self.sigCurrentItemChanged.emit(previousItem, currentItem)
+
+ if previousSelected != self.getSelectedItems():
+ self.sigSelectedItemsChanged.emit()
+
+ def _activeImageChanged(self, previous, current):
+ """Handle active image change"""
+ self.__activeItemChanged('image', previous, current)
+
+ def _activeCurveChanged(self, previous, current):
+ """Handle active curve change"""
+ self.__activeItemChanged('curve', previous, current)
+
+ def _activeScatterChanged(self, previous, current):
+ """Handle active scatter change"""
+ self.__activeItemChanged('scatter', previous, current)
+
+
class PlotWidget(qt.QMainWindow):
"""Qt Widget providing a 1D/2D plot.
@@ -313,6 +474,9 @@ class PlotWidget(qt.QMainWindow):
self._foregroundColorsUpdated()
self._backgroundColorsUpdated()
+ # selection handling
+ self.__selection = None
+
def __getBackendClass(self, backend):
"""Returns backend class corresponding to backend.
@@ -374,6 +538,12 @@ class PlotWidget(qt.QMainWindow):
raise ValueError("Backend not supported %s" % str(backend))
+ def selection(self):
+ """Returns the selection hander"""
+ if self.__selection is None: # Lazy initialization
+ self.__selection = _PlotWidgetSelection(parent=self)
+ return self.__selection
+
# TODO: Can be removed for silx 0.10
@staticmethod
@deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
@@ -849,6 +1019,21 @@ class PlotWidget(qt.QMainWindow):
self.notify('contentChanged', action='remove',
kind=kind, legend=item.getName())
+ def discardItem(self, item) -> bool:
+ """Remove the item from the plot.
+
+ Same as :meth:`removeItem` but do not raise an exception.
+
+ :param ~silx.gui.plot.items.Item item: Item to remove from the plot.
+ :returns: True if the item was present, False otherwise.
+ """
+ try:
+ self.removeItem(item)
+ except ValueError:
+ return False
+ else:
+ return True
+
@deprecated(replacement='addItem', since_version='0.13')
def _add(self, item):
return self.addItem(item)
@@ -910,8 +1095,8 @@ class PlotWidget(qt.QMainWindow):
:param numpy.ndarray y: The data corresponding to the y coordinates
:param str legend: The legend to be associated to the curve (or None)
:param info: User-defined information associated to the curve
- :param bool replace: True (the default) to delete already existing
- curves
+ :param bool replace: True to delete already existing curves
+ (the default is False)
:param color: color(s) to be used
:type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
one of the predefined color names defined in colors.py
diff --git a/silx/gui/plot/Profile.py b/silx/gui/plot/Profile.py
index 8abddbe..7565155 100644
--- a/silx/gui/plot/Profile.py
+++ b/silx/gui/plot/Profile.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -57,16 +57,43 @@ class _CustomProfileManager(manager.ProfileManager):
if it is specified. Else the behavior is the same as the default
ProfileManager """
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.__profileWindow = None
+ self.__specializedProfileWindows = {}
+
+ def setSpecializedProfileWindow(self, roiClass, profileWindow):
+ """Set a profile window for a given class or ROI.
+
+ Setting profileWindow to None removes the roiClass from the list.
+
+ :param roiClass:
+ :param profileWindow:
+ """
+ if profileWindow is None:
+ self.__specializedProfileWindows.pop(roiClass, None)
+ else:
+ self.__specializedProfileWindows[roiClass] = profileWindow
+
def setProfileWindow(self, profileWindow):
self.__profileWindow = profileWindow
def createProfileWindow(self, plot, roi):
+ for roiClass, specializedProfileWindow in self.__specializedProfileWindows.items():
+ if isinstance(roi, roiClass):
+ return specializedProfileWindow
+
if self.__profileWindow is not None:
return self.__profileWindow
else:
return super(_CustomProfileManager, self).createProfileWindow(plot, roi)
def clearProfileWindow(self, profileWindow):
+ for specializedProfileWindow in self.__specializedProfileWindows.values():
+ if profileWindow is specializedProfileWindow:
+ profileWindow.setProfile(None)
+ return
+
if self.__profileWindow is not None:
self.__profileWindow.setProfile(None)
else:
@@ -116,7 +143,7 @@ class ProfileToolBar(qt.QToolBar):
# If a profileWindow is defined,
# It will be used to display all the profiles
- self._manager = _CustomProfileManager(self, plot)
+ self._manager = self.createProfileManager(self, plot)
self._manager.setProfileWindow(profileWindow)
self._manager.setDefaultColorFromCursorColor(True)
self._manager.setItemType(image=True)
@@ -155,6 +182,9 @@ class ProfileToolBar(qt.QToolBar):
plot.sigActiveImageChanged.connect(self._activeImageChanged)
self._activeImageChanged()
+ def createProfileManager(self, parent, plot):
+ return _CustomProfileManager(parent, plot)
+
def _createProfileActions(self):
self.hLineAction = self._manager.createProfileAction(rois.ProfileImageHorizontalLineROI, self)
self.vLineAction = self._manager.createProfileAction(rois.ProfileImageVerticalLineROI, self)
diff --git a/silx/gui/plot/StatsWidget.py b/silx/gui/plot/StatsWidget.py
index 26b48db..6d8739e 100644
--- a/silx/gui/plot/StatsWidget.py
+++ b/silx/gui/plot/StatsWidget.py
@@ -868,6 +868,12 @@ class StatsTable(_StatsWidgetBase, TableWidget):
statsHandler = self.getStatsHandler()
if statsHandler is not None:
+ # _updateStats is call when the plot visible area change.
+ # to force stats update we consider roi changed
+ if self._statsOnVisibleData:
+ roi_changed = True
+ else:
+ roi_changed = False
stats = statsHandler.calculate(
item, plot, self._statsOnVisibleData,
data_changed=data_changed, roi_changed=roi_changed)
diff --git a/silx/gui/plot/_BaseMaskToolsWidget.py b/silx/gui/plot/_BaseMaskToolsWidget.py
index 3298498..407ab11 100644
--- a/silx/gui/plot/_BaseMaskToolsWidget.py
+++ b/silx/gui/plot/_BaseMaskToolsWidget.py
@@ -29,7 +29,7 @@ from __future__ import division
__authors__ = ["T. Vincent", "P. Knobel"]
__license__ = "MIT"
-__date__ = "12/04/2019"
+__date__ = "08/12/2020"
import os
import weakref
@@ -60,6 +60,9 @@ class BaseMask(qt.QObject):
sigChanged = qt.Signal()
"""Signal emitted when the mask has changed"""
+ sigStateChanged = qt.Signal()
+ """Signal emitted for each mask commit/undo/redo operation"""
+
sigUndoable = qt.Signal(bool)
"""Signal emitted when undo becomes possible/impossible"""
@@ -81,7 +84,6 @@ class BaseMask(qt.QObject):
if dataItem is not None:
self.setDataItem(dataItem)
self.reset(self.getDataValues().shape)
-
super(BaseMask, self).__init__()
def setDataItem(self, item):
@@ -92,6 +94,13 @@ class BaseMask(qt.QObject):
"""
self._dataItem = item
+ def getDataItem(self):
+ """Returns current plot item the mask is on.
+
+ :rtype: Union[~silx.gui.plot.items.Item,None]
+ """
+ return self._dataItem
+
def getDataValues(self):
"""Return data values, as a numpy array with the same shape
as the mask.
@@ -152,6 +161,7 @@ class BaseMask(qt.QObject):
if len(self._history) == 2:
self.sigUndoable.emit(True)
+ self.sigStateChanged.emit()
def undo(self):
"""Restore previous mask if any"""
@@ -164,6 +174,7 @@ class BaseMask(qt.QObject):
self.sigRedoable.emit(True)
if len(self._history) == 1: # Last value in history
self.sigUndoable.emit(False)
+ self.sigStateChanged.emit()
def redo(self):
"""Restore previously undone modification if any"""
@@ -176,8 +187,9 @@ class BaseMask(qt.QObject):
self.sigRedoable.emit(False)
if len(self._history) == 2: # Something to undo
self.sigUndoable.emit(True)
+ self.sigStateChanged.emit()
- # Whole mask operations
+ # Whole mask operations
def clear(self, level):
"""Set all values of the given mask level to 0.
@@ -211,7 +223,7 @@ class BaseMask(qt.QObject):
"""
if shape is None:
# assume dimensionality never changes
- shape = (0, ) * len(self._mask.shape) # empty array
+ shape = (0,) * len(self._mask.shape) # empty array
shapeChanged = (shape != self._mask.shape)
self._mask = numpy.zeros(shape, dtype=numpy.uint8)
if shapeChanged:
@@ -415,6 +427,13 @@ class BaseMaskToolsWidget(qt.QWidget):
"""Notify mask changes"""
self.sigMaskChanged.emit()
+ def getMaskedItem(self):
+ """Returns the item that is currently being masked
+
+ :rtype: Union[~silx.gui.plot.items.Item,None]
+ """
+ return self._mask.getDataItem()
+
def getSelectionMask(self, copy=True):
"""Get the current mask as a numpy array.
@@ -935,11 +954,11 @@ class BaseMaskToolsWidget(qt.QWidget):
colors = numpy.empty((self._maxLevelNumber + 1, 4), dtype=numpy.float32)
# Set color
- colors[:, :3] = self._defaultOverlayColor[:3]
+ colors[:,:3] = self._defaultOverlayColor[:3]
# check if some colors has been directly set by the user
mask = numpy.equal(self._defaultColors, False)
- colors[mask, :3] = self._overlayColors[mask, :3]
+ colors[mask,:3] = self._overlayColors[mask,:3]
# Set alpha
colors[:, -1] = alpha / 2.
diff --git a/silx/gui/plot/actions/control.py b/silx/gui/plot/actions/control.py
index 182ac78..439985e 100755
--- a/silx/gui/plot/actions/control.py
+++ b/silx/gui/plot/actions/control.py
@@ -374,22 +374,7 @@ class ColormapAction(PlotAction):
return
image = self.plot.getActiveImage()
- if isinstance(image, items.ImageComplexData):
- # Specific init for complex images
- colormap = image.getColormap()
-
- mode = image.getComplexMode()
- if mode in (items.ImageComplexData.ComplexMode.AMPLITUDE_PHASE,
- items.ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE):
- data = image.getData(
- copy=False, mode=items.ImageComplexData.ComplexMode.PHASE)
- else:
- data = image.getData(copy=False)
-
- # Set histogram and range if any
- self._dialog.setData(data)
-
- elif isinstance(image, items.ColormapMixIn):
+ if isinstance(image, items.ColormapMixIn):
# Set dialog from active image
colormap = image.getColormap()
# Set histogram and range if any
diff --git a/silx/gui/plot/actions/histogram.py b/silx/gui/plot/actions/histogram.py
index f3e6370..0bba558 100644
--- a/silx/gui/plot/actions/histogram.py
+++ b/silx/gui/plot/actions/histogram.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -34,22 +34,238 @@ The following QAction are available:
from __future__ import division
__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
-__date__ = "10/10/2018"
+__date__ = "01/12/2020"
__license__ = "MIT"
import numpy
import logging
+import typing
import weakref
from .PlotToolAction import PlotToolAction
+
from silx.math.histogram import Histogramnd
from silx.math.combo import min_max
from silx.gui import qt
from silx.gui.plot import items
+from silx.gui.widgets.ElidedLabel import ElidedLabel
+from silx.utils.deprecation import deprecated
_logger = logging.getLogger(__name__)
+class _ElidedLabel(ElidedLabel):
+ """QLabel with a default size larger than what is displayed."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
+
+ def sizeHint(self):
+ hint = super().sizeHint()
+ nbchar = max(len(self.getText()), 12)
+ width = self.fontMetrics().boundingRect('#' * nbchar).width()
+ return qt.QSize(max(hint.width(), width), hint.height())
+
+
+class _StatWidget(qt.QWidget):
+ """Widget displaying a name and a value
+
+ :param parent:
+ :param name:
+ """
+
+ def __init__(self, parent=None, name: str=''):
+ super().__init__(parent)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ keyWidget = qt.QLabel(parent=self)
+ keyWidget.setText("<b>" + name.capitalize() + ":<b>")
+ layout.addWidget(keyWidget)
+ self.__valueWidget = _ElidedLabel(parent=self)
+ self.__valueWidget.setText("-")
+ self.__valueWidget.setTextInteractionFlags(
+ qt.Qt.TextSelectableByMouse | qt.Qt.TextSelectableByKeyboard)
+ layout.addWidget(self.__valueWidget)
+
+ def setValue(self, value: typing.Optional[float]):
+ """Set the displayed value
+
+ :param value:
+ """
+ self.__valueWidget.setText(
+ "-" if value is None else "{:.5g}".format(value))
+
+
+class HistogramWidget(qt.QWidget):
+ """Widget displaying a histogram and some statistic indicators"""
+
+ _SUPPORTED_ITEM_CLASS = items.ImageBase, items.Scatter
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setWindowTitle('Histogram')
+
+ self.__itemRef = None # weakref on the item to track
+
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ # Plot
+ # Lazy import to avoid circular dependencies
+ from silx.gui.plot.PlotWindow import Plot1D
+ self.__plot = Plot1D(self)
+ layout.addWidget(self.__plot)
+
+ self.__plot.setDataMargins(0.1, 0.1, 0.1, 0.1)
+ self.__plot.getXAxis().setLabel("Value")
+ self.__plot.getYAxis().setLabel("Count")
+ posInfo = self.__plot.getPositionInfoWidget()
+ posInfo.setSnappingMode(posInfo.SNAPPING_CURVE)
+
+ # Stats display
+ statsWidget = qt.QWidget(self)
+ layout.addWidget(statsWidget)
+ statsLayout = qt.QHBoxLayout(statsWidget)
+ statsLayout.setContentsMargins(4, 4, 4, 4)
+
+ self.__statsWidgets = dict(
+ (name, _StatWidget(parent=statsWidget, name=name))
+ for name in ("min", "max", "mean", "std", "sum"))
+
+ for widget in self.__statsWidgets.values():
+ statsLayout.addWidget(widget)
+ statsLayout.addStretch(1)
+
+ def getPlotWidget(self):
+ """Returns :class:`PlotWidget` use to display the histogram"""
+ return self.__plot
+
+ def resetZoom(self):
+ """Reset PlotWidget zoom"""
+ self.getPlotWidget().resetZoom()
+
+ def reset(self):
+ """Clear displayed information"""
+ self.getPlotWidget().clear()
+ self.setStatistics()
+
+ def getItem(self) -> typing.Optional[items.Item]:
+ """Returns item used to display histogram and statistics."""
+ return None if self.__itemRef is None else self.__itemRef()
+
+ def setItem(self, item: typing.Optional[items.Item]):
+ """Set item from which to display histogram and statistics.
+
+ :param item:
+ """
+ previous = self.getItem()
+ if previous is not None:
+ previous.sigItemChanged.disconnect(self.__itemChanged)
+
+ self.__itemRef = None if item is None else weakref.ref(item)
+ if item is not None:
+ if isinstance(item, self._SUPPORTED_ITEM_CLASS):
+ # Only listen signal for supported items
+ item.sigItemChanged.connect(self.__itemChanged)
+ self._updateFromItem()
+
+ def __itemChanged(self, event):
+ """Handle update of the item"""
+ if event in (items.ItemChangedType.DATA, items.ItemChangedType.MASK):
+ self._updateFromItem()
+
+ def _updateFromItem(self):
+ """Update histogram and stats from the item"""
+ item = self.getItem()
+
+ if item is None:
+ self.reset()
+ return
+
+ if not isinstance(item, self._SUPPORTED_ITEM_CLASS):
+ _logger.error("Unsupported item", item)
+ self.reset()
+ return
+
+ # Compute histogram and stats
+ array = item.getValueData(copy=False)
+
+ if array.size == 0:
+ self.reset()
+ return
+
+ xmin, xmax = min_max(array, min_positive=False, finite=True)
+ nbins = min(1024, int(numpy.sqrt(array.size)))
+ data_range = xmin, xmax
+
+ # bad hack: get 256 bins in the case we have a B&W
+ if numpy.issubdtype(array.dtype, numpy.integer):
+ if nbins > xmax - xmin:
+ nbins = xmax - xmin
+
+ nbins = max(2, nbins)
+
+ data = array.ravel().astype(numpy.float32)
+ histogram = Histogramnd(data, n_bins=nbins, histo_range=data_range)
+ if len(histogram.edges) != 1:
+ _logger.error("Error while computing the histogram")
+ self.reset()
+ return
+
+ self.setHistogram(histogram.histo, histogram.edges[0])
+ self.resetZoom()
+ self.setStatistics(
+ min_=xmin,
+ max_=xmax,
+ mean=numpy.nanmean(array),
+ std=numpy.nanstd(array),
+ sum_=numpy.nansum(array))
+
+ def setHistogram(self, histogram, edges):
+ """Set displayed histogram
+
+ :param histogram: Bin values (N)
+ :param edges: Bin edges (N+1)
+ """
+ self.getPlotWidget().addHistogram(
+ histogram=histogram,
+ edges=edges,
+ legend='histogram',
+ fill=True,
+ color='#66aad7',
+ resetzoom=False)
+
+ def getHistogram(self, copy: bool=True):
+ """Returns currently displayed histogram.
+
+ :param copy: True to get a copy,
+ False to get internal representation (Do not modify!)
+ :return: (histogram, edges) or None
+ """
+ for item in self.getPlotWidget().getItems():
+ if item.getName() == 'histogram':
+ return (item.getValueData(copy=copy),
+ item.getBinEdgesData(copy=copy))
+ else:
+ return None
+
+ def setStatistics(self,
+ min_: typing.Optional[float] = None,
+ max_: typing.Optional[float] = None,
+ mean: typing.Optional[float] = None,
+ std: typing.Optional[float] = None,
+ sum_: typing.Optional[float] = None):
+ """Set displayed statistic indicators."""
+ self.__statsWidgets['min'].setValue(min_)
+ self.__statsWidgets['max'].setValue(max_)
+ self.__statsWidgets['mean'].setValue(mean)
+ self.__statsWidgets['std'].setValue(std)
+ self.__statsWidgets['sum'].setValue(sum_)
+
+
class _LastActiveItem(qt.QObject):
sigActiveItemChanged = qt.Signal(object, object)
@@ -98,20 +314,26 @@ class _LastActiveItem(qt.QObject):
def _activeImageChanged(self, previous, current):
"""Handle active image change"""
plot = self.getPlotWidget()
- item = plot.getImage(current)
- if item is None:
- self.setActiveItem(None)
- elif isinstance(item, items.ImageBase):
- self.setActiveItem(item)
+ if current is None: # Fall-back to active scatter if any
+ self.setActiveItem(plot.getActiveScatter())
else:
- # Do not touch anything, which is consistent with silx v0.12 behavior
- pass
+ item = plot.getImage(current)
+ if item is None:
+ self.setActiveItem(None)
+ elif isinstance(item, items.ImageBase):
+ self.setActiveItem(item)
+ else:
+ # Do not touch anything, which is consistent with silx v0.12 behavior
+ pass
def _activeScatterChanged(self, previous, current):
"""Handle active scatter change"""
plot = self.getPlotWidget()
- item = plot.getScatter(current)
- self.setActiveItem(item)
+ if current is None: # Fall-back to active image if any
+ self.setActiveItem(plot.getActiveImage())
+ else:
+ item = plot.getScatter(current)
+ self.setActiveItem(item)
class PixelIntensitiesHistoAction(PlotToolAction):
@@ -129,130 +351,42 @@ class PixelIntensitiesHistoAction(PlotToolAction):
tooltip='Compute image intensity distribution',
parent=parent)
self._lastItemFilter = _LastActiveItem(self, plot)
- self._histo = None
- self._item = None
def _connectPlot(self, window):
self._lastItemFilter.sigActiveItemChanged.connect(self._activeItemChanged)
item = self._lastItemFilter.getActiveItem()
- self._setSelectedItem(item)
+ self.getHistogramWidget().setItem(item)
PlotToolAction._connectPlot(self, window)
def _disconnectPlot(self, window):
self._lastItemFilter.sigActiveItemChanged.disconnect(self._activeItemChanged)
PlotToolAction._disconnectPlot(self, window)
- self._setSelectedItem(None)
-
- def _getSelectedItem(self):
- item = self._item
- if item is None:
- return None
- else:
- return item()
+ self.getHistogramWidget().setItem(None)
def _activeItemChanged(self, previous, current):
if self._isWindowInUse():
- self._setSelectedItem(current)
-
- def _setSelectedItem(self, item):
- if item is not None:
- if not isinstance(item, (items.ImageBase, items.Scatter)):
- # Filter out other things
- return
-
- old = self._getSelectedItem()
- if item is old:
- return
- if old is not None:
- old.sigItemChanged.disconnect(self._itemUpdated)
- if item is None:
- self._item = None
- else:
- self._item = weakref.ref(item)
- item.sigItemChanged.connect(self._itemUpdated)
- self.computeIntensityDistribution()
-
- def _itemUpdated(self, event):
- if event == items.ItemChangedType.DATA:
- self.computeIntensityDistribution()
-
- def _cleanUp(self):
- plot = self.getHistogramPlotWidget()
- try:
- plot.remove('pixel intensity', kind='item')
- except Exception:
- pass
+ self.getHistogramWidget().setItem(current)
+ @deprecated(since_version='0.15.0')
def computeIntensityDistribution(self):
- """Get the active image and compute the image intensity distribution
- """
- item = self._getSelectedItem()
-
- if item is None:
- self._cleanUp()
- return
-
- if isinstance(item, items.ImageBase):
- array = item.getData(copy=False)
- if array.ndim == 3: # RGB(A) images
- _logger.info('Converting current image from RGB(A) to grayscale\
- in order to compute the intensity distribution')
- array = (array[:, :, 0] * 0.299 +
- array[:, :, 1] * 0.587 +
- array[:, :, 2] * 0.114)
- elif isinstance(item, items.Scatter):
- array = item.getValueData(copy=False)
- else:
- assert(False)
-
- if array.size == 0:
- self._cleanUp()
- return
-
- xmin, xmax = min_max(array, min_positive=False, finite=True)
- nbins = min(1024, int(numpy.sqrt(array.size)))
- data_range = xmin, xmax
-
- # bad hack: get 256 bins in the case we have a B&W
- if numpy.issubdtype(array.dtype, numpy.integer):
- if nbins > xmax - xmin:
- nbins = xmax - xmin
+ self.getHistogramWidget()._updateFromItem()
- nbins = max(2, nbins)
-
- data = array.ravel().astype(numpy.float32)
- histogram = Histogramnd(data, n_bins=nbins, histo_range=data_range)
- assert len(histogram.edges) == 1
- self._histo = histogram.histo
- edges = histogram.edges[0]
- plot = self.getHistogramPlotWidget()
- plot.addHistogram(histogram=self._histo,
- edges=edges,
- legend='pixel intensity',
- fill=True,
- color='#66aad7')
- plot.resetZoom()
+ def getHistogramWidget(self):
+ """Returns the widget displaying the histogram"""
+ return self._getToolWindow()
+ @deprecated(since_version='0.15.0',
+ replacement='getHistogramWidget().getPlotWidget()')
def getHistogramPlotWidget(self):
- """Create the plot histogram if needed, otherwise create it
-
- :return: the PlotWidget showing the histogram of the pixel intensities
- """
- return self._getToolWindow()
+ return self._getToolWindow().getPlotWidget()
def _createToolWindow(self):
- from silx.gui.plot.PlotWindow import Plot1D
- window = Plot1D(parent=self.plot)
- window.setWindowFlags(qt.Qt.Window)
- window.setWindowTitle('Image Intensity Histogram')
- window.setDataMargins(0.1, 0.1, 0.1, 0.1)
- window.getXAxis().setLabel("Value")
- window.getYAxis().setLabel("Count")
- return window
-
- def getHistogram(self):
+ return HistogramWidget(self.plot, qt.Qt.Window)
+
+ def getHistogram(self) -> typing.Optional[numpy.ndarray]:
"""Return the last computed histogram
- :return: the histogram displayed in the HistogramPlotWiget
+ :return: the histogram displayed in the HistogramWidget
"""
- return self._histo
+ histogram = self.getHistogramWidget().getHistogram()
+ return None if histogram is None else histogram[0]
diff --git a/silx/gui/plot/backends/BackendMatplotlib.py b/silx/gui/plot/backends/BackendMatplotlib.py
index 140672f..432b0b0 100755
--- a/silx/gui/plot/backends/BackendMatplotlib.py
+++ b/silx/gui/plot/backends/BackendMatplotlib.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -288,11 +288,17 @@ class _TextWithOffset(Text):
yoffset = 0
trans = self.get_transform()
- invtrans = self.get_transform().inverted()
-
x = super(_TextWithOffset, self).convert_xunits(self._x)
y = super(_TextWithOffset, self).convert_xunits(self._y)
pos = x, y
+
+ try:
+ invtrans = trans.inverted()
+ except numpy.linalg.LinAlgError:
+ # Cannot inverse transform, fallback: pos without offset
+ self.__cache = None
+ return pos
+
proj = trans.transform_point(pos)
proj = proj + numpy.array((xoffset, yoffset))
pos = invtrans.transform_point(proj)
diff --git a/silx/gui/plot/backends/BackendOpenGL.py b/silx/gui/plot/backends/BackendOpenGL.py
index 909d18a..6fde9df 100755
--- a/silx/gui/plot/backends/BackendOpenGL.py
+++ b/silx/gui/plot/backends/BackendOpenGL.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -990,7 +990,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
item.getYAxis() == 'right')
self._plotFrame.isY2Axis = next(y2AxisItems, None) is not None
- self._glGarbageCollector.append(item)
+ if item.isInitialized():
+ self._glGarbageCollector.append(item)
elif isinstance(item, (_MarkerItem, _ShapeItem)):
pass # No-op
diff --git a/silx/gui/plot/backends/glutils/GLPlotCurve.py b/silx/gui/plot/backends/glutils/GLPlotCurve.py
index c4e2c1e..34844c6 100644
--- a/silx/gui/plot/backends/glutils/GLPlotCurve.py
+++ b/silx/gui/plot/backends/glutils/GLPlotCurve.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -235,12 +235,15 @@ class _Fill2D(object):
def discard(self):
"""Release VBOs"""
- if self._xFillVboData is not None:
+ if self.isInitialized():
self._xFillVboData.vbo.discard()
self._xFillVboData = None
self._yFillVboData = None
+ def isInitialized(self):
+ return self._xFillVboData is not None
+
# line ########################################################################
@@ -1061,13 +1064,16 @@ class _ErrorBars(object):
def discard(self):
"""Release VBOs"""
- if self._attribs is not None:
+ if self.isInitialized():
self._lines.xVboData, self._lines.yVboData = None, None
self._xErrPoints.xVboData, self._xErrPoints.yVboData = None, None
self._yErrPoints.xVboData, self._yErrPoints.yVboData = None, None
self._attribs[0].vbo.discard()
self._attribs = None
+ def isInitialized(self):
+ return self._attribs is not None
+
# curves ######################################################################
@@ -1272,6 +1278,11 @@ class GLPlotCurve2D(GLPlotItem):
if self.fill is not None:
self.fill.discard()
+ def isInitialized(self):
+ return (self.xVboData is not None or
+ self._errorBars.isInitialized() or
+ (self.fill is not None and self.fill.isInitialized()))
+
def pick(self, xPickMin, yPickMin, xPickMax, yPickMax):
"""Perform picking on the curve according to its rendering.
diff --git a/silx/gui/plot/backends/glutils/GLPlotImage.py b/silx/gui/plot/backends/glutils/GLPlotImage.py
index f60a159..3ad94b9 100644
--- a/silx/gui/plot/backends/glutils/GLPlotImage.py
+++ b/silx/gui/plot/backends/glutils/GLPlotImage.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -286,6 +286,10 @@ class GLPlotColormap(_GLPlotData2D):
self._texture = None
self._textureIsDirty = False
+ def isInitialized(self):
+ return (self._cmap_texture is not None or
+ self._texture is not None)
+
@property
def cmapRange(self):
if self.normalization == 'log':
@@ -622,11 +626,14 @@ class GLPlotRGBAImage(_GLPlotData2D):
return self._alpha
def discard(self):
- if self._texture is not None:
+ if self.isInitialized():
self._texture.discard()
self._texture = None
self._textureIsDirty = False
+ def isInitialized(self):
+ return self._texture is not None
+
def updateData(self, data):
assert data.dtype in self._SUPPORTED_DTYPES
oldData = self.data
diff --git a/silx/gui/plot/backends/glutils/GLPlotItem.py b/silx/gui/plot/backends/glutils/GLPlotItem.py
index 899f38e..ae13091 100644
--- a/silx/gui/plot/backends/glutils/GLPlotItem.py
+++ b/silx/gui/plot/backends/glutils/GLPlotItem.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2020 European Synchrotron Radiation Facility
+# Copyright (c) 2020-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -92,3 +92,8 @@ class GLPlotItem:
def discard(self):
"""Discards OpenGL resources this item has created."""
pass
+
+ def isInitialized(self) -> bool:
+ """Returns True if resources where initialized and requires `discard`.
+ """
+ return True
diff --git a/silx/gui/plot/backends/glutils/GLPlotTriangles.py b/silx/gui/plot/backends/glutils/GLPlotTriangles.py
index d5ba1a6..fbe9e02 100644
--- a/silx/gui/plot/backends/glutils/GLPlotTriangles.py
+++ b/silx/gui/plot/backends/glutils/GLPlotTriangles.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2019-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2019-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -144,12 +144,15 @@ class GLPlotTriangles(GLPlotItem):
def discard(self):
"""Release resources on the GPU"""
- if self.__vbos is not None:
+ if self.isInitialized():
self.__vbos[0].vbo.discard()
self.__vbos = None
self.__indicesVbo.discard()
self.__indicesVbo = None
+ def isInitialized(self):
+ return self.__vbos is not None
+
def prepare(self):
"""Allocate resources on the GPU"""
if self.__vbos is None:
diff --git a/silx/gui/plot/items/_arc_roi.py b/silx/gui/plot/items/_arc_roi.py
index a22cc3d..23416ec 100644
--- a/silx/gui/plot/items/_arc_roi.py
+++ b/silx/gui/plot/items/_arc_roi.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,6 +29,7 @@ __authors__ = ["V. Valls"]
__license__ = "MIT"
__date__ = "28/06/2018"
+import logging
import numpy
from ... import utils
@@ -40,6 +41,9 @@ from ._roi_base import InteractionModeMixIn
from ._roi_base import RoiInteractionMode
+logger = logging.getLogger(__name__)
+
+
class _ArcGeometry:
"""
Non-mutable object to store the geometry of the arc ROI.
@@ -779,8 +783,9 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
If `startAngle` is smaller than `endAngle` the rotation is clockwise,
else the rotation is anticlockwise.
"""
- assert innerRadius <= outerRadius
- assert numpy.abs(startAngle - endAngle) <= 2 * numpy.pi
+ if innerRadius > outerRadius:
+ logger.error("inner radius larger than outer radius")
+ innerRadius, outerRadius = outerRadius, innerRadius
center = numpy.array(center)
radius = (innerRadius + outerRadius) * 0.5
weight = outerRadius - innerRadius
diff --git a/silx/gui/plot/items/complex.py b/silx/gui/plot/items/complex.py
index 0e492a0..abb64ad 100644
--- a/silx/gui/plot/items/complex.py
+++ b/silx/gui/plot/items/complex.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -184,18 +184,18 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
def setComplexMode(self, mode):
changed = super(ImageComplexData, self).setComplexMode(mode)
if changed:
+ self._valueDataChanged()
+
# Backward compatibility
self._updated(ItemChangedType.VISUALIZATION_MODE)
- # Send data updated as value returned by getData has changed
- self._updated(ItemChangedType.DATA)
-
# Update ColormapMixIn colormap
colormap = self._colormaps[self.getComplexMode()]
if colormap is not super(ImageComplexData, self).getColormap():
super(ImageComplexData, self).setColormap(colormap)
- self._setColormappedData(self.getData(copy=False), copy=False)
+ # Send data updated as value returned by getData has changed
+ self._updated(ItemChangedType.DATA)
return changed
def _setAmplitudeRangeInfo(self, max_=None, delta=2):
@@ -263,10 +263,32 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
'Image is not complex, converting it to complex to plot it.')
data = numpy.array(data, dtype=numpy.complex64)
- self._dataByModesCache = {}
- self._setColormappedData(self.getData(copy=False), copy=False)
+ # Compute current mode data and set colormap data
+ mode = self.getComplexMode()
+ dataForMode = self.__convertComplexData(data, self.getComplexMode())
+ self._dataByModesCache = {mode: dataForMode}
+
super().setData(data)
+ def _updated(self, event=None, checkVisibility=True):
+ # Synchronizes colormapped data if changed
+ # ItemChangedType.COMPLEX_MODE triggers ItemChangedType.DATA
+ # No need to handle it twice.
+ if event in (ItemChangedType.DATA, ItemChangedType.MASK):
+ # Color-mapped data is NOT the `getValueData` for some modes
+ if self.getComplexMode() in (
+ self.ComplexMode.AMPLITUDE_PHASE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE):
+ data = self.getData(copy=False, mode=self.ComplexMode.PHASE)
+ mask = self.getMaskData(copy=False)
+ if mask is not None:
+ data = numpy.copy(data)
+ data[mask != 0] = numpy.nan
+ else:
+ data = self.getValueData(copy=False)
+ self._setColormappedData(data, copy=False)
+ super()._updated(event=event, checkVisibility=checkVisibility)
+
def getComplexData(self, copy=True):
"""Returns the image complex data
@@ -276,6 +298,31 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
"""
return super().getData(copy=copy)
+ def __convertComplexData(self, data, mode):
+ """Convert complex data to given mode.
+
+ :param numpy.ndarray data:
+ :param Union[ComplexMode,str] mode:
+ :rtype: numpy.ndarray of float
+ """
+ if mode is self.ComplexMode.PHASE:
+ return numpy.angle(data)
+ elif mode is self.ComplexMode.REAL:
+ return numpy.real(data)
+ elif mode is self.ComplexMode.IMAGINARY:
+ return numpy.imag(data)
+ elif mode in (self.ComplexMode.ABSOLUTE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ self.ComplexMode.AMPLITUDE_PHASE):
+ return numpy.absolute(data)
+ elif mode is self.ComplexMode.SQUARE_AMPLITUDE:
+ return numpy.absolute(data) ** 2
+ else:
+ _logger.error(
+ 'Unsupported conversion mode: %s, fallback to absolute',
+ str(mode))
+ return numpy.absolute(data)
+
def getData(self, copy=True, mode=None):
"""Returns the image data corresponding to (current) mode.
@@ -295,27 +342,8 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
mode = self.ComplexMode.from_value(mode)
if mode not in self._dataByModesCache:
- # Compute data for mode and store it in cache
- complexData = self.getComplexData(copy=False)
- if mode is self.ComplexMode.PHASE:
- data = numpy.angle(complexData)
- elif mode is self.ComplexMode.REAL:
- data = numpy.real(complexData)
- elif mode is self.ComplexMode.IMAGINARY:
- data = numpy.imag(complexData)
- elif mode in (self.ComplexMode.ABSOLUTE,
- self.ComplexMode.LOG10_AMPLITUDE_PHASE,
- self.ComplexMode.AMPLITUDE_PHASE):
- data = numpy.absolute(complexData)
- elif mode is self.ComplexMode.SQUARE_AMPLITUDE:
- data = numpy.absolute(complexData) ** 2
- else:
- _logger.error(
- 'Unsupported conversion mode: %s, fallback to absolute',
- str(mode))
- data = numpy.absolute(complexData)
-
- self._dataByModesCache[mode] = data
+ self._dataByModesCache[mode] = self.__convertComplexData(
+ self.getComplexData(copy=False), mode)
return numpy.array(self._dataByModesCache[mode], copy=copy)
diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py
index edc6d89..95a65ad 100644
--- a/silx/gui/plot/items/core.py
+++ b/silx/gui/plot/items/core.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,7 +27,7 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "29/01/2019"
+__date__ = "08/12/2020"
import collections
try:
@@ -110,6 +110,9 @@ class ItemChangedType(enum.Enum):
DATA = 'dataChanged'
"""Item's data changed flag"""
+ MASK = 'maskChanged'
+ """Item's mask changed flag"""
+
HIGHLIGHTED = 'highlightedChanged'
"""Item's highlight state changed flag."""
@@ -315,7 +318,7 @@ class Item(qt.QObject):
info = deepcopy(info)
self._info = info
- def getVisibleBounds(self) -> Optional[Tuple[float,float,float,float]]:
+ def getVisibleBounds(self) -> Optional[Tuple[float, float, float, float]]:
"""Returns visible bounds of the item bounding box in the plot area.
:returns:
@@ -503,9 +506,9 @@ class DataItem(Item):
self._boundsChanged(checkVisibility=False)
super().setVisible(visible)
-
# Mix-in classes ##############################################################
+
class ItemMixInBase(object):
"""Base class for Item mix-in"""
@@ -1232,7 +1235,7 @@ class ScatterVisualizationMixIn(ItemMixInBase):
def __init__(self):
self.__visualization = self.Visualization.POINTS
- self.__parameters = dict( # Init parameters to None
+ self.__parameters = dict(# Init parameters to None
(parameter, None) for parameter in self.VisualizationParameter)
self.__parameters[self.VisualizationParameter.BINNED_STATISTIC_FUNCTION] = 'mean'
@@ -1404,8 +1407,8 @@ class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
elif error.ndim == 1: # N array
newError = numpy.empty((2, len(value)),
dtype=numpy.float64)
- newError[0, :] = error
- newError[1, :] = error
+ newError[0,:] = error
+ newError[1,:] = error
error = newError
elif error.size == 2 * len(value): # 2xN array
@@ -1610,14 +1613,32 @@ class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
assert len(x) == len(y)
assert x.ndim == y.ndim == 1
+ # Convert complex data
+ if numpy.iscomplexobj(x):
+ _logger.warning(
+ 'Converting x data to absolute value to plot it.')
+ x = numpy.absolute(x)
+ if numpy.iscomplexobj(y):
+ _logger.warning(
+ 'Converting y data to absolute value to plot it.')
+ y = numpy.absolute(y)
+
if xerror is not None:
if isinstance(xerror, abc.Iterable):
xerror = numpy.array(xerror, copy=copy)
+ if numpy.iscomplexobj(xerror):
+ _logger.warning(
+ 'Converting xerror data to absolute value to plot it.')
+ xerror = numpy.absolute(xerror)
else:
xerror = float(xerror)
if yerror is not None:
if isinstance(yerror, abc.Iterable):
yerror = numpy.array(yerror, copy=copy)
+ if numpy.iscomplexobj(yerror):
+ _logger.warning(
+ 'Converting yerror data to absolute value to plot it.')
+ yerror = numpy.absolute(yerror)
else:
yerror = float(yerror)
# TODO checks on xerror, yerror
@@ -1634,6 +1655,7 @@ class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
class BaselineMixIn(object):
"""Base class for Baseline mix-in"""
+
def __init__(self, baseline=None):
self._baseline = baseline
diff --git a/silx/gui/plot/items/histogram.py b/silx/gui/plot/items/histogram.py
index 5941cc6..16bbefa 100644
--- a/silx/gui/plot/items/histogram.py
+++ b/silx/gui/plot/items/histogram.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,6 +30,7 @@ __license__ = "MIT"
__date__ = "28/08/2018"
import logging
+import typing
import numpy
from collections import OrderedDict, namedtuple
@@ -38,8 +39,10 @@ try:
except ImportError: # Python2 support
import collections as abc
+from ....utils.proxy import docstring
from .core import (DataItem, AlphaMixIn, BaselineMixIn, ColorMixIn, FillMixIn,
- LineMixIn, YAxisMixIn, ItemChangedType)
+ LineMixIn, YAxisMixIn, ItemChangedType, Item)
+from ._pick import PickingResult
_logger = logging.getLogger(__name__)
@@ -219,6 +222,53 @@ class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
min(0, numpy.nanmin(values)),
max(0, numpy.nanmax(values)))
+ def __pickFilledHistogram(self, x: float, y: float) -> typing.Optional[PickingResult]:
+ """Picking implementation for filled histogram
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ """
+ if not self.isFill():
+ return None
+
+ plot = self.getPlot()
+ if plot is None:
+ return None
+
+ xData, yData = plot.pixelToData(x, y, axis=self.getYAxis())
+ xmin, xmax, ymin, ymax = self.getBounds()
+ if not xmin < xData < xmax or not ymin < yData < ymax:
+ return None # Outside bounding box
+
+ # Check x
+ edges = self.getBinEdgesData(copy=False)
+ index = numpy.searchsorted(edges, (xData,), side='left')[0] - 1
+ # Safe indexing in histogram values
+ index = numpy.clip(index, 0, len(edges) - 2)
+
+ # Check y
+ baseline = self.getBaseline(copy=False)
+ if baseline is None:
+ baseline = 0 # Default value
+
+ value = self.getValueData(copy=False)[index]
+ if ((baseline <= value and baseline <= yData <= value) or
+ (value < baseline and value <= yData <= baseline)):
+ return PickingResult(self, numpy.array([index]))
+ else:
+ return None
+
+ @docstring(DataItem)
+ def pick(self, x, y):
+ if self.isFill():
+ return self.__pickFilledHistogram(x, y)
+ else:
+ result = super().pick(x, y)
+ if result is None:
+ return None
+ else: # Convert from curve indices to histogram indices
+ return PickingResult(self, numpy.unique(result.getIndices() // 2))
+
def getValueData(self, copy=True):
"""The values of the histogram
diff --git a/silx/gui/plot/items/image.py b/silx/gui/plot/items/image.py
index fda4245..0d9c9a4 100644
--- a/silx/gui/plot/items/image.py
+++ b/silx/gui/plot/items/image.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,8 +28,7 @@ of the :class:`Plot`.
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "20/10/2017"
-
+__date__ = "08/12/2020"
try:
from collections import abc
@@ -43,7 +42,6 @@ from ....utils.proxy import docstring
from .core import (DataItem, LabelsMixIn, DraggableMixIn, ColormapMixIn,
AlphaMixIn, ItemChangedType)
-
_logger = logging.getLogger(__name__)
@@ -80,8 +78,8 @@ def _convertImageToRgba32(image, copy=True):
if image.shape[-1] == 3:
new_image = numpy.empty((image.shape[0], image.shape[1], 4),
dtype=numpy.uint8)
- new_image[:, :, :3] = image
- new_image[:, :, 3] = 255
+ new_image[:,:,:3] = image
+ new_image[:,:, 3] = 255
return new_image # This is a copy anyway
else:
return numpy.array(image, copy=copy)
@@ -93,7 +91,7 @@ class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
:param numpy.ndarray data: Initial image data
"""
- def __init__(self, data=None):
+ def __init__(self, data=None, mask=None):
DataItem.__init__(self)
LabelsMixIn.__init__(self)
DraggableMixIn.__init__(self)
@@ -101,7 +99,8 @@ class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
if data is None:
data = numpy.zeros((0, 0, 4), dtype=numpy.uint8)
self._data = data
-
+ self._mask = mask
+ self.__valueDataCache = None # Store default data
self._origin = (0., 0.)
self._scale = (1., 1.)
@@ -186,13 +185,98 @@ class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
:param numpy.ndarray data:
"""
+ previousShape = self._data.shape
self._data = data
+ self._valueDataChanged()
self._boundsChanged()
self._updated(ItemChangedType.DATA)
+ if (self.getMaskData(copy=False) is not None and
+ previousShape != self._data.shape):
+ # Data shape changed, so mask shape changes.
+ # Send event, mask is lazily updated in getMaskData
+ self._updated(ItemChangedType.MASK)
+
+ def getMaskData(self, copy=True):
+ """Returns the mask data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: Union[None,numpy.ndarray]
+ """
+ if self._mask is None:
+ return None
+
+ # Update mask if it does not match data shape
+ shape = self.getData(copy=False).shape[:2]
+ if self._mask.shape != shape:
+ # Clip/extend mask to match data
+ newMask = numpy.zeros(shape, dtype=self._mask.dtype)
+ newMask[:self._mask.shape[0], :self._mask.shape[1]] = self._mask[:shape[0], :shape[1]]
+ self._mask = newMask
+
+ return numpy.array(self._mask, copy=copy)
+
+ def setMaskData(self, mask, copy=True):
+ """Set the image data
+
+ :param numpy.ndarray data:
+ :param bool copy: True (Default) to make a copy,
+ False to use as is (do not modify!)
+ """
+ if mask is not None:
+ mask = numpy.array(mask, copy=copy)
+
+ shape = self.getData(copy=False).shape[:2]
+ if mask.shape != shape:
+ _logger.warning("Inconsistent shape between mask and data %s, %s", mask.shape, shape)
+ # Clip/extent is done lazily in getMaskData
+ elif self._mask is None:
+ return # No update
+
+ self._mask = mask
+ self._valueDataChanged()
+ self._updated(ItemChangedType.MASK)
+
+ def _valueDataChanged(self):
+ """Clear cache of default data array"""
+ self.__valueDataCache = None
+
+ def _getValueData(self, copy=True):
+ """Return data used by :meth:`getValueData`
+
+ :param bool copy:
+ :rtype: numpy.ndarray
+ """
+ return self.getData(copy=copy)
+
+ def getValueData(self, copy=True):
+ """Return data (converted to int or float) with mask applied.
+
+ Masked values are set to Not-A-Number.
+ It returns a 2D array of values (int or float).
+
+ :param bool copy:
+ :rtype: numpy.ndarray
+ """
+ if self.__valueDataCache is None:
+ data = self._getValueData(copy=False)
+ mask = self.getMaskData(copy=False)
+ if mask is not None:
+ if numpy.issubdtype(data.dtype, numpy.floating):
+ dtype = data.dtype
+ else:
+ dtype = numpy.float64
+ data = numpy.array(data, dtype=dtype, copy=True)
+ data[mask != 0] = numpy.NaN
+ self.__valueDataCache = data
+ return numpy.array(self.__valueDataCache, copy=copy)
+
def getRgbaImageData(self, copy=True):
"""Get the displayed RGB(A) image
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
:returns: numpy.ndarray of uint8 of shape (height, width, 4)
"""
raise NotImplementedError('This MUST be implemented in sub-class')
@@ -308,7 +392,7 @@ class ImageData(ImageBase, ColormapMixIn):
alphaImage = self.getAlphaData(copy=False)
if alphaImage is not None:
# Apply transparency
- image[:, :, 3] = image[:, :, 3] * alphaImage
+ image[:,:, 3] = image[:,:, 3] * alphaImage
return image
def getAlternativeImageData(self, copy=True):
@@ -358,7 +442,6 @@ class ImageData(ImageBase, ColormapMixIn):
_logger.warning(
'Converting complex image to absolute value to plot it.')
data = numpy.absolute(data)
- self._setColormappedData(data, copy=False)
if alternative is not None:
alternative = numpy.array(alternative, copy=copy)
@@ -378,6 +461,14 @@ class ImageData(ImageBase, ColormapMixIn):
super().setData(data)
+ def _updated(self, event=None, checkVisibility=True):
+ # Synchronizes colormapped data if changed
+ if event in (ItemChangedType.DATA, ItemChangedType.MASK):
+ self._setColormappedData(
+ self.getValueData(copy=False),
+ copy=False)
+ super()._updated(event=event, checkVisibility=checkVisibility)
+
class ImageRgba(ImageBase):
"""Description of an RGB(A) image"""
@@ -423,6 +514,20 @@ class ImageRgba(ImageBase):
assert data.shape[-1] in (3, 4)
super().setData(data)
+ def _getValueData(self, copy=True):
+ """Compute the intensity of the RGBA image as default data.
+
+ Conversion: https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion
+
+ :param bool copy:
+ """
+ rgba = self.getRgbaImageData(copy=False).astype(numpy.float32)
+ intensity = (rgba[:, :, 0] * 0.299 +
+ rgba[:, :, 1] * 0.587 +
+ rgba[:, :, 2] * 0.114)
+ intensity *= rgba[:, :, 3] / 255.
+ return intensity
+
class MaskImageData(ImageData):
"""Description of an image used as a mask.
diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py
index fd7cfae..2d54223 100644
--- a/silx/gui/plot/items/scatter.py
+++ b/silx/gui/plot/items/scatter.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -935,6 +935,12 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
assert value.ndim == 1
assert len(x) == len(value)
+ # Convert complex data
+ if numpy.iscomplexobj(value):
+ _logger.warning(
+ 'Converting value data to absolute value to plot it.')
+ value = numpy.absolute(value)
+
# Reset triangulation and interpolator
if self.__delaunayFuture is not None:
self.__delaunayFuture.cancel()
diff --git a/silx/gui/plot/stats/stats.py b/silx/gui/plot/stats/stats.py
index 755b185..a81f7bb 100644
--- a/silx/gui/plot/stats/stats.py
+++ b/silx/gui/plot/stats/stats.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -193,19 +193,6 @@ class _StatsContext(object):
self.clipData(item, plot, onlimits, roi=roi)
- def clipData(self, item, plot, onlimits, roi):
- """
- Clip the data to the current mask to have accurate statistics
-
- :param item: item for whiwh we want to clip data
- :param plot: plot containing the item
- :param onlimits: do we want to apply statistic only on
- visible data.
- :param roi: Region of interest for computing the statistics.
- :type roi: Union[None,:class:`_RegionOfInterestBase`]
- """
- raise NotImplementedError()
-
def clear_mask(self):
"""
Remove the mask to force recomputation of it on next iteration
@@ -232,7 +219,8 @@ class _StatsContext(object):
raise NotImplementedError("Base class")
def clipData(self, item, plot, onlimits, roi):
- """
+ """Clip the data to the current mask to have accurate statistics
+
Function called before computing each statistics associated to this
context. It will insure the context for the (item, plot, onlimits, roi)
is created.
@@ -340,9 +328,8 @@ class _CurveContext(_ScatterCurveHistoMixInContext):
mask = self.mask
else:
mask = (minX <= xData) & (xData <= maxX)
- yData = yData[mask]
- xData = xData[mask]
- mask = numpy.zeros_like(yData)
+ mask = mask == 0
+ self._set_mask_validity(onlimits=onlimits, from_=minX, to_=maxX)
elif roi:
minX, maxX = roi.getFrom(), roi.getTo()
if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX):
@@ -350,10 +337,11 @@ class _CurveContext(_ScatterCurveHistoMixInContext):
else:
mask = (minX <= xData) & (xData <= maxX)
mask = mask == 0
- mask = mask.astype(numpy.int32)
+ self._set_mask_validity(onlimits=onlimits, from_=minX, to_=maxX)
else:
mask = numpy.zeros_like(yData)
+ mask = mask.astype(numpy.uint32)
self.xData = xData
self.yData = yData
self.values = numpy.ma.array(yData, mask=mask)
@@ -363,7 +351,6 @@ class _CurveContext(_ScatterCurveHistoMixInContext):
else:
self.min, self.max = None, None
self.data = (xData, yData)
-
self.axes = (xData,)
def _checkContextInputs(self, item, plot, onlimits, roi):
@@ -399,38 +386,33 @@ class _HistogramContext(_ScatterCurveHistoMixInContext):
if onlimits:
minX, maxX = plot.getXAxis().getLimits()
- if self.is_mask_valid(onlimits, from_=minX, to_=maxX):
+ if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX):
mask = self.mask
else:
mask = (minX <= xData) & (xData <= maxX)
- self._set_mask_validity(onlimits=True, from_=minX, to_=maxX)
+ mask = mask == 0
+ self._set_mask_validity(onlimits=onlimits, from_=minX, to_=maxX)
elif roi:
- if self.is_mask_valid(onlimits, from_=roi._fromdata, to_=roi._todata):
+ if self.is_mask_valid(onlimits=onlimits, from_=roi._fromdata, to_=roi._todata):
mask = self.mask
else:
mask = (roi._fromdata <= xData) & (xData <= roi._todata)
mask = mask == 0
- self._set_mask_validity(onlimits=True, from_=roi._fromdata,
+ self._set_mask_validity(onlimits=onlimits, from_=roi._fromdata,
to_=roi._todata)
else:
- mask = numpy.zeros_like(self.data)
-
- if onlimits:
- yData = yData[mask]
- xData = xData[mask]
-
- self.data = (xData, yData)
- self.values = numpy.ma.array(yData, mask=mask)
- self.axes = (xData,)
-
+ mask = numpy.zeros_like(yData)
+ mask = mask.astype(numpy.uint32)
self.xData = xData
self.yData = yData
-
+ self.values = numpy.ma.array(yData, mask=(mask))
unmasked_data = self.values.compressed()
if len(unmasked_data) > 0:
self.min, self.max = min_max(unmasked_data)
else:
self.min, self.max = None, None
+ self.data = (self.xData, self.yData)
+ self.axes = (self.xData,)
def _checkContextInputs(self, item, plot, onlimits, roi):
_StatsContext._checkContextInputs(self, item=item, plot=plot,
diff --git a/silx/gui/plot/test/testMaskToolsWidget.py b/silx/gui/plot/test/testMaskToolsWidget.py
index 2e8db55..c22975f 100644
--- a/silx/gui/plot/test/testMaskToolsWidget.py
+++ b/silx/gui/plot/test/testMaskToolsWidget.py
@@ -136,6 +136,15 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.mouseRelease(
plot, qt.Qt.LeftButton, pos=star[-1])
+ def _isMaskItemSync(self):
+ """Check if masks from item and tools are sync or not"""
+ if self.maskWidget.isItemMaskUpdated():
+ return numpy.all(numpy.equal(
+ self.maskWidget.getSelectionMask(),
+ self.plot.getActiveImage().getMaskData(copy=False)))
+ else:
+ return True
+
def testWithAnImage(self):
"""Plot with an image: test MaskToolsWidget interactions"""
@@ -152,80 +161,91 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
((0, 0), (-1, -1)),
((1000, 1000), (-1, -1))]
- for origin, scale in tests:
- with self.subTest(origin=origin, scale=scale):
- self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024),
- legend='test',
- origin=origin,
- scale=scale)
- self.qapp.processEvents()
-
- # Test draw rectangle #
- toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
-
- # mask
- self.maskWidget.maskStateGroup.button(1).click()
- self.qapp.processEvents()
- self._drag()
- self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- # unmask same region
- self.maskWidget.maskStateGroup.button(0).click()
- self.qapp.processEvents()
- self._drag()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- # Test draw polygon #
- toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
-
- # mask
- self.maskWidget.maskStateGroup.button(1).click()
- self.qapp.processEvents()
- self._drawPolygon()
- self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- # unmask same region
- self.maskWidget.maskStateGroup.button(0).click()
- self.qapp.processEvents()
- self._drawPolygon()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- # Test draw pencil #
- toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
-
- self.maskWidget.pencilSpinBox.setValue(30)
- self.qapp.processEvents()
-
- # mask
- self.maskWidget.maskStateGroup.button(1).click()
- self.qapp.processEvents()
- self._drawPencil()
- self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- # unmask same region
- self.maskWidget.maskStateGroup.button(0).click()
- self.qapp.processEvents()
- self._drawPencil()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- # Test no draw tool #
- toolButton = getQToolButtonFromAction(self.maskWidget.browseAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
-
- self.plot.clear()
+ for itemMaskUpdated in (False, True):
+ for origin, scale in tests:
+ with self.subTest(origin=origin, scale=scale):
+ self.maskWidget.setItemMaskUpdated(itemMaskUpdated)
+ self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024),
+ legend='test',
+ origin=origin,
+ scale=scale)
+ self.qapp.processEvents()
+
+ self.assertEqual(
+ self.maskWidget.isItemMaskUpdated(), itemMaskUpdated)
+
+ # Test draw rectangle #
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # Test draw polygon #
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # Test draw pencil #
+ toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.maskWidget.pencilSpinBox.setValue(30)
+ self.qapp.processEvents()
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # Test no draw tool #
+ toolButton = getQToolButtonFromAction(self.maskWidget.browseAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.plot.clear()
def __loadSave(self, file_format):
"""Plot with an image: test MaskToolsWidget operations"""
diff --git a/silx/gui/plot/test/testPixelIntensityHistoAction.py b/silx/gui/plot/test/testPixelIntensityHistoAction.py
index 882f496..ac29952 100644
--- a/silx/gui/plot/test/testPixelIntensityHistoAction.py
+++ b/silx/gui/plot/test/testPixelIntensityHistoAction.py
@@ -65,7 +65,7 @@ class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase):
self.mouseMove(button)
self.mouseClick(button, qt.Qt.LeftButton)
self.qapp.processEvents()
- self.assertTrue(histoAction.getHistogramPlotWidget().isVisible())
+ self.assertTrue(histoAction.getHistogramWidget().isVisible())
# test the pixel intensity diagram is hiding
self.qapp.setActiveWindow(self.plotImage)
@@ -73,7 +73,7 @@ class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase):
self.mouseMove(button)
self.mouseClick(button, qt.Qt.LeftButton)
self.qapp.processEvents()
- self.assertFalse(histoAction.getHistogramPlotWidget().isVisible())
+ self.assertFalse(histoAction.getHistogramWidget().isVisible())
def testImageFormatInput(self):
"""Test multiple type as image input"""
@@ -108,9 +108,9 @@ class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase):
self.mouseClick(button, qt.Qt.LeftButton)
self.qapp.processEvents()
- plot = histoAction.getHistogramPlotWidget()
- self.assertTrue(plot.isVisible())
- items = plot.getItems()
+ widget = histoAction.getHistogramWidget()
+ self.assertTrue(widget.isVisible())
+ items = widget.getPlotWidget().getItems()
self.assertEqual(len(items), 1)
def testChangeItem(self):
@@ -131,9 +131,9 @@ class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase):
self.qapp.processEvents()
# Reach histogram from the first item
- plot = histoAction.getHistogramPlotWidget()
- self.assertTrue(plot.isVisible())
- items = plot.getItems()
+ widget = histoAction.getHistogramWidget()
+ self.assertTrue(widget.isVisible())
+ items = widget.getPlotWidget().getItems()
data1 = items[0].getValueData(copy=False)
# Set another item to the plot
diff --git a/silx/gui/plot/test/testPlotWidget.py b/silx/gui/plot/test/testPlotWidget.py
index f9d2281..b55260e 100755
--- a/silx/gui/plot/test/testPlotWidget.py
+++ b/silx/gui/plot/test/testPlotWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -197,6 +197,21 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
self.assertTrue(numpy.all(numpy.equal(items[4].getPosition()[0], marker_x)))
self.assertEqual(items[5].getType(), 'rectangle')
+ def testRemoveDiscardItem(self):
+ """Test removeItem and discardItem"""
+ self.plot.addCurve((1, 2, 3), (1, 2, 3))
+ curve = self.plot.getItems()[0]
+ self.plot.removeItem(curve)
+ with self.assertRaises(ValueError):
+ self.plot.removeItem(curve)
+
+ self.plot.addCurve((1, 2, 3), (1, 2, 3))
+ curve = self.plot.getItems()[0]
+ result = self.plot.discardItem(curve)
+ self.assertTrue(result)
+ result = self.plot.discardItem(curve)
+ self.assertFalse(result)
+
def testBackGroundColors(self):
self.plot.setVisible(True)
self.qWaitForWindowExposed(self.plot)
@@ -559,6 +574,11 @@ class TestPlotCurve(PlotWidgetTestCase):
self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True,
baseline=list(range(0, 100, 1)))
+ def testPlotCurveComplexData(self):
+ """Test curve with complex data"""
+ data = numpy.arange(100.) + 1j
+ self.plot.addCurve(x=data, y=data, xerror=data, yerror=data)
+
class TestPlotHistogram(PlotWidgetTestCase):
"""Basic tests for add Histogram"""
@@ -592,6 +612,13 @@ class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase):
self.plot.addScatter(x, y, value)
self.plot.resetZoom()
+ def testScatterComplexData(self):
+ """Test scatter item with complex data"""
+ data = numpy.arange(100.) + 1j
+ self.plot.addScatter(
+ x=data, y=data, value=data, xerror=data, yerror=data)
+ self.plot.resetZoom()
+
def testScatterVisualization(self):
self.plot.addScatter((0, 1, 0, 1), (0, 0, 2, 2), (0, 1, 2, 3))
self.plot.resetZoom()
@@ -1857,6 +1884,153 @@ class TestPlotWidgetSwitchBackend(PlotWidgetTestCase):
self.assertEqual(self.plot.getItems(), items)
+class TestPlotWidgetSelection(PlotWidgetTestCase):
+ """Test PlotWidget.selection and active items handling"""
+
+ def _checkSelection(self, selection, current=None, selected=()):
+ """Check current item and selected items."""
+ self.assertIs(selection.getCurrentItem(), current)
+ self.assertEqual(selection.getSelectedItems(), selected)
+
+ def testSyncWithActiveItems(self):
+ """Test update of PlotWidgetSelection according to active items"""
+ listener = SignalListener()
+
+ selection = self.plot.selection()
+ selection.sigCurrentItemChanged.connect(listener)
+ self._checkSelection(selection)
+
+ # Active item is current
+ self.plot.addImage(((0, 1), (2, 3)), legend='image')
+ image = self.plot.getActiveImage()
+ self.assertEqual(listener.callCount(), 1)
+ self._checkSelection(selection, image, (image,))
+
+ # No active = no current
+ self.plot.setActiveImage(None)
+ self.assertEqual(listener.callCount(), 2)
+ self._checkSelection(selection)
+
+ # Active item is current
+ self.plot.setActiveImage('image')
+ self.assertEqual(listener.callCount(), 3)
+ self._checkSelection(selection, image, (image,))
+
+ # Mosted recently "actived" item is current
+ self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
+ scatter = self.plot.getActiveScatter()
+ self.assertEqual(listener.callCount(), 4)
+ self._checkSelection(selection, scatter, (scatter, image))
+
+ # Previously mosted recently "actived" item is current
+ self.plot.setActiveScatter(None)
+ self.assertEqual(listener.callCount(), 5)
+ self._checkSelection(selection, image, (image,))
+
+ # Mosted recently "actived" item is current
+ self.plot.setActiveScatter('scatter')
+ self.assertEqual(listener.callCount(), 6)
+ self._checkSelection(selection, scatter, (scatter, image))
+
+ # No active = no current
+ self.plot.setActiveImage(None)
+ self.plot.setActiveScatter(None)
+ self.assertEqual(listener.callCount(), 7)
+ self._checkSelection(selection)
+
+ # Mosted recently "actived" item is current
+ self.plot.setActiveScatter('scatter')
+ self.assertEqual(listener.callCount(), 8)
+ self.plot.setActiveImage('image')
+ self.assertEqual(listener.callCount(), 9)
+ self._checkSelection(selection, image, (image, scatter))
+
+ # Add a curve which is not active by default
+ self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
+ curve = self.plot.getCurve('curve')
+ self.assertEqual(listener.callCount(), 9)
+ self._checkSelection(selection, image, (image, scatter))
+
+ # Mosted recently "actived" item is current
+ self.plot.setActiveCurve('curve')
+ self.assertEqual(listener.callCount(), 10)
+ self._checkSelection(selection, curve, (curve, image, scatter))
+
+ # Add a curve which is not active by default
+ self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve2')
+ curve2 = self.plot.getCurve('curve2')
+ self.assertEqual(listener.callCount(), 10)
+ self._checkSelection(selection, curve, (curve, image, scatter))
+
+ # Mosted recently "actived" item is current, previous curve is removed
+ self.plot.setActiveCurve('curve2')
+ self.assertEqual(listener.callCount(), 11)
+ self._checkSelection(selection, curve2, (curve2, image, scatter))
+
+ # No items = no current
+ self.plot.clear()
+ self.assertEqual(listener.callCount(), 12)
+ self._checkSelection(selection)
+
+ def testPlotWidgetWithItems(self):
+ """Test init of selection on a plot with items"""
+ self.plot.addImage(((0, 1), (2, 3)), legend='image')
+ self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
+ self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
+ self.plot.setActiveCurve('curve')
+
+ selection = self.plot.selection()
+ self.assertIsNotNone(selection.getCurrentItem())
+ selected = selection.getSelectedItems()
+ self.assertEqual(len(selected), 3)
+ self.assertIn(self.plot.getActiveCurve(), selected)
+ self.assertIn(self.plot.getActiveImage(), selected)
+ self.assertIn(self.plot.getActiveScatter(), selected)
+
+ def testSetCurrentItem(self):
+ """Test setCurrentItem"""
+ # Add items to the plot
+ self.plot.addImage(((0, 1), (2, 3)), legend='image')
+ image = self.plot.getActiveImage()
+ self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
+ scatter = self.plot.getActiveScatter()
+ self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
+ self.plot.setActiveCurve('curve')
+ curve = self.plot.getActiveCurve()
+
+ selection = self.plot.selection()
+ self.assertIsNotNone(selection.getCurrentItem())
+ self.assertEqual(len(selection.getSelectedItems()), 3)
+
+ # Set current to None reset all active items
+ selection.setCurrentItem(None)
+ self._checkSelection(selection)
+ self.assertIsNone(self.plot.getActiveCurve())
+ self.assertIsNone(self.plot.getActiveImage())
+ self.assertIsNone(self.plot.getActiveScatter())
+
+ # Set current to an item makes it active
+ selection.setCurrentItem(image)
+ self._checkSelection(selection, image, (image,))
+ self.assertIsNone(self.plot.getActiveCurve())
+ self.assertIs(self.plot.getActiveImage(), image)
+ self.assertIsNone(self.plot.getActiveScatter())
+
+ # Set current to an item makes it active and keeps other active
+ selection.setCurrentItem(curve)
+ self._checkSelection(selection, curve, (curve, image))
+ self.assertIs(self.plot.getActiveCurve(), curve)
+ self.assertIs(self.plot.getActiveImage(), image)
+ self.assertIsNone(self.plot.getActiveScatter())
+
+ # Set current to an item makes it active and keeps other active
+ selection.setCurrentItem(scatter)
+ self._checkSelection(selection, scatter, (scatter, curve, image))
+ self.assertIs(self.plot.getActiveCurve(), curve)
+ self.assertIs(self.plot.getActiveImage(), image)
+ self.assertIs(self.plot.getActiveScatter(), scatter)
+
+
def suite():
testClasses = (TestPlotWidget,
TestPlotImage,
@@ -1870,7 +2044,8 @@ def suite():
TestPlotEmptyLog,
TestPlotCurveLog,
TestPlotImageLog,
- TestPlotMarkerLog)
+ TestPlotMarkerLog,
+ TestPlotWidgetSelection)
test_suite = unittest.TestSuite()
diff --git a/silx/gui/plot/tools/PositionInfo.py b/silx/gui/plot/tools/PositionInfo.py
index 4b63cdb..81d312a 100644
--- a/silx/gui/plot/tools/PositionInfo.py
+++ b/silx/gui/plot/tools/PositionInfo.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -44,11 +44,25 @@ import numpy
from ....utils.deprecation import deprecated
from ... import qt
from .. import items
+from ...widgets.ElidedLabel import ElidedLabel
_logger = logging.getLogger(__name__)
+class _PositionInfoLabel(ElidedLabel):
+ """QLabel with a default size larger than what is displayed."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
+
+ def sizeHint(self):
+ hint = super().sizeHint()
+ width = self.fontMetrics().boundingRect('##############').width()
+ return qt.QSize(max(hint.width(), width), hint.height())
+
+
# PositionInfo ################################################################
class PositionInfo(qt.QWidget):
@@ -117,11 +131,8 @@ class PositionInfo(qt.QWidget):
for name, func in converters:
layout.addWidget(qt.QLabel('<b>' + name + ':</b>'))
- contentWidget = qt.QLabel()
+ contentWidget = _PositionInfoLabel(self)
contentWidget.setText('------')
- contentWidget.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
- contentWidget.setFixedWidth(
- contentWidget.fontMetrics().boundingRect('##############').width())
layout.addWidget(contentWidget)
self._fields.append((contentWidget, name, func))
@@ -213,10 +224,11 @@ class PositionInfo(qt.QWidget):
kinds = []
if snappingMode & self.SNAPPING_CURVE:
kinds.append(items.Curve)
+ kinds.append(items.Histogram)
if snappingMode & self.SNAPPING_SCATTER:
kinds.append(items.Scatter)
selectedItems = [item for item in plot.getItems()
- if isinstance(item, kinds) and item.isVisible()]
+ if isinstance(item, tuple(kinds)) and item.isVisible()]
# Compute distance threshold
if qt.BINDING in ('PyQt5', 'PySide2'):
@@ -233,38 +245,54 @@ class PositionInfo(qt.QWidget):
distInPixels = (self.SNAP_THRESHOLD_DIST * ratio)**2
for item in selectedItems:
- if (snappingMode & self.SNAPPING_SYMBOLS_ONLY and
- not item.getSymbol()):
+ if (snappingMode & self.SNAPPING_SYMBOLS_ONLY and (
+ not isinstance(item, items.SymbolMixIn) or
+ not item.getSymbol())):
# Only handled if item symbols are visible
continue
- xArray = item.getXData(copy=False)
- yArray = item.getYData(copy=False)
- closestIndex = numpy.argmin(
- pow(xArray - x, 2) + pow(yArray - y, 2))
+ if isinstance(item, items.Histogram):
+ result = item.pick(xPixel, yPixel)
+ if result is not None: # Histogram picked
+ index = result.getIndices()[0]
+ edges = item.getBinEdgesData(copy=False)
- xClosest = xArray[closestIndex]
- yClosest = yArray[closestIndex]
+ # Snap to bin center and value
+ xData = 0.5 * (edges[index] + edges[index + 1])
+ yData = item.getValueData(copy=False)[index]
- if isinstance(item, items.YAxisMixIn):
- axis = item.getYAxis()
- else:
- axis = 'left'
-
- closestInPixels = plot.dataToPixel(
- xClosest, yClosest, axis=axis)
- if closestInPixels is not None:
- curveDistInPixels = (
- (closestInPixels[0] - xPixel)**2 +
- (closestInPixels[1] - yPixel)**2)
-
- if curveDistInPixels <= distInPixels:
# Update label style sheet
styleSheet = "color: rgb(0, 0, 0);"
-
- # if close enough, snap to data point coord
- xData, yData = xClosest, yClosest
- distInPixels = curveDistInPixels
+ break
+
+ else: # Curve, Scatter
+ xArray = item.getXData(copy=False)
+ yArray = item.getYData(copy=False)
+ closestIndex = numpy.argmin(
+ pow(xArray - x, 2) + pow(yArray - y, 2))
+
+ xClosest = xArray[closestIndex]
+ yClosest = yArray[closestIndex]
+
+ if isinstance(item, items.YAxisMixIn):
+ axis = item.getYAxis()
+ else:
+ axis = 'left'
+
+ closestInPixels = plot.dataToPixel(
+ xClosest, yClosest, axis=axis)
+ if closestInPixels is not None:
+ curveDistInPixels = (
+ (closestInPixels[0] - xPixel)**2 +
+ (closestInPixels[1] - yPixel)**2)
+
+ if curveDistInPixels <= distInPixels:
+ # Update label style sheet
+ styleSheet = "color: rgb(0, 0, 0);"
+
+ # if close enough, snap to data point coord
+ xData, yData = xClosest, yClosest
+ distInPixels = curveDistInPixels
for label, name, func in self._fields:
label.setStyleSheet(styleSheet)
diff --git a/silx/gui/plot/tools/RadarView.py b/silx/gui/plot/tools/RadarView.py
new file mode 100644
index 0000000..7076835
--- /dev/null
+++ b/silx/gui/plot/tools/RadarView.py
@@ -0,0 +1,361 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""QWidget displaying an overview of a 2D plot.
+
+This shows the available range of the data, and the current location of the
+plot view.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "22/02/2021"
+
+import logging
+import weakref
+from ... import qt
+from ...utils import LockReentrant
+
+_logger = logging.getLogger(__name__)
+
+
+class _DraggableRectItem(qt.QGraphicsRectItem):
+ """RectItem which signals its change through visibleRectDragged."""
+ def __init__(self, *args, **kwargs):
+ super(_DraggableRectItem, self).__init__(
+ *args, **kwargs)
+
+ self._previousCursor = None
+ self.setFlag(qt.QGraphicsItem.ItemIsMovable)
+ self.setFlag(qt.QGraphicsItem.ItemSendsGeometryChanges)
+ self.setAcceptHoverEvents(True)
+ self._ignoreChange = False
+ self._constraint = 0, 0, 0, 0
+
+ def setConstraintRect(self, left, top, width, height):
+ """Set the constraint rectangle for dragging.
+
+ The coordinates are in the _DraggableRectItem coordinate system.
+
+ This constraint only applies to modification through interaction
+ (i.e., this constraint is not applied to change through API).
+
+ If the _DraggableRectItem is smaller than the constraint rectangle,
+ the _DraggableRectItem remains within the constraint rectangle.
+ If the _DraggableRectItem is wider than the constraint rectangle,
+ the constraint rectangle remains within the _DraggableRectItem.
+ """
+ self._constraint = left, left + width, top, top + height
+
+ def setPos(self, *args, **kwargs):
+ """Overridden to ignore changes from API in itemChange."""
+ self._ignoreChange = True
+ super(_DraggableRectItem, self).setPos(*args, **kwargs)
+ self._ignoreChange = False
+
+ def moveBy(self, *args, **kwargs):
+ """Overridden to ignore changes from API in itemChange."""
+ self._ignoreChange = True
+ super(_DraggableRectItem, self).moveBy(*args, **kwargs)
+ self._ignoreChange = False
+
+ def itemChange(self, change, value):
+ """Callback called before applying changes to the item."""
+ if (change == qt.QGraphicsItem.ItemPositionChange and
+ not self._ignoreChange):
+ # Makes sure that the visible area is in the data
+ # or that data is in the visible area if area is too wide
+ x, y = value.x(), value.y()
+ xMin, xMax, yMin, yMax = self._constraint
+
+ if self.rect().width() <= (xMax - xMin):
+ if x < xMin:
+ value.setX(xMin)
+ elif x > xMax - self.rect().width():
+ value.setX(xMax - self.rect().width())
+ else:
+ if x > xMin:
+ value.setX(xMin)
+ elif x < xMax - self.rect().width():
+ value.setX(xMax - self.rect().width())
+
+ if self.rect().height() <= (yMax - yMin):
+ if y < yMin:
+ value.setY(yMin)
+ elif y > yMax - self.rect().height():
+ value.setY(yMax - self.rect().height())
+ else:
+ if y > yMin:
+ value.setY(yMin)
+ elif y < yMax - self.rect().height():
+ value.setY(yMax - self.rect().height())
+
+ if self.pos() != value:
+ # Notify change through signal
+ views = self.scene().views()
+ assert len(views) == 1
+ views[0].visibleRectDragged.emit(
+ value.x() + self.rect().left(),
+ value.y() + self.rect().top(),
+ self.rect().width(),
+ self.rect().height())
+
+ return value
+
+ return super(_DraggableRectItem, self).itemChange(
+ change, value)
+
+ def hoverEnterEvent(self, event):
+ """Called when the mouse enters the rectangle area"""
+ self._previousCursor = self.cursor()
+ self.setCursor(qt.Qt.OpenHandCursor)
+
+ def hoverLeaveEvent(self, event):
+ """Called when the mouse leaves the rectangle area"""
+ if self._previousCursor is not None:
+ self.setCursor(self._previousCursor)
+ self._previousCursor = None
+
+
+class RadarView(qt.QGraphicsView):
+ """Widget presenting a synthetic view of a 2D area and
+ the current visible area.
+
+ Coordinates are as in QGraphicsView:
+ x goes from left to right and y goes from top to bottom.
+ This widget preserves the aspect ratio of the areas.
+
+ The 2D area and the visible area can be set with :meth:`setDataRect`
+ and :meth:`setVisibleRect`.
+ When the visible area has been dragged by the user, its new position
+ is signaled by the *visibleRectDragged* signal.
+
+ It is possible to invert the direction of the axes by using the
+ :meth:`scale` method of QGraphicsView.
+ """
+
+ visibleRectDragged = qt.Signal(float, float, float, float)
+ """Signals that the visible rectangle has been dragged.
+
+ It provides: left, top, width, height in data coordinates.
+ """
+
+ _DATA_PEN = qt.QPen(qt.QColor('white'))
+ _DATA_BRUSH = qt.QBrush(qt.QColor('light gray'))
+ _ACTIVEDATA_PEN = qt.QPen(qt.QColor('black'))
+ _ACTIVEDATA_BRUSH = qt.QBrush(qt.QColor('transparent'))
+ _ACTIVEDATA_PEN.setWidth(2)
+ _ACTIVEDATA_PEN.setCosmetic(True)
+ _VISIBLE_PEN = qt.QPen(qt.QColor('blue'))
+ _VISIBLE_PEN.setWidth(2)
+ _VISIBLE_PEN.setCosmetic(True)
+ _VISIBLE_BRUSH = qt.QBrush(qt.QColor(0, 0, 0, 0))
+ _TOOLTIP = 'Radar View:\nRed contour: Visible area\nGray area: The image'
+
+ _PIXMAP_SIZE = 256
+
+ def __init__(self, parent=None):
+ self.__plotRef = None
+ self._scene = qt.QGraphicsScene()
+ self._dataRect = self._scene.addRect(0, 0, 1, 1,
+ self._DATA_PEN,
+ self._DATA_BRUSH)
+ self._imageRect = self._scene.addRect(0, 0, 1, 1,
+ self._ACTIVEDATA_PEN,
+ self._ACTIVEDATA_BRUSH)
+ self._imageRect.setVisible(False)
+ self._scatterRect = self._scene.addRect(0, 0, 1, 1,
+ self._ACTIVEDATA_PEN,
+ self._ACTIVEDATA_BRUSH)
+ self._scatterRect.setVisible(False)
+ self._curveRect = self._scene.addRect(0, 0, 1, 1,
+ self._ACTIVEDATA_PEN,
+ self._ACTIVEDATA_BRUSH)
+ self._curveRect.setVisible(False)
+
+ self._visibleRect = _DraggableRectItem(0, 0, 1, 1)
+ self._visibleRect.setPen(self._VISIBLE_PEN)
+ self._visibleRect.setBrush(self._VISIBLE_BRUSH)
+ self._scene.addItem(self._visibleRect)
+
+ super(RadarView, self).__init__(self._scene, parent)
+ self.setHorizontalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff)
+ self.setVerticalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff)
+ self.setFocusPolicy(qt.Qt.NoFocus)
+ self.setStyleSheet('border: 0px')
+ self.setToolTip(self._TOOLTIP)
+
+ self.__reentrant = LockReentrant()
+ self.visibleRectDragged.connect(self._viewRectDragged)
+
+ self.__timer = qt.QTimer(self)
+ self.__timer.timeout.connect(self._updateDataContent)
+
+ def sizeHint(self):
+ # """Overridden to avoid sizeHint to depend on content size."""
+ return self.minimumSizeHint()
+
+ def wheelEvent(self, event):
+ # """Overridden to disable vertical scrolling with wheel."""
+ event.ignore()
+
+ def resizeEvent(self, event):
+ # """Overridden to fit current content to new size."""
+ self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+ super(RadarView, self).resizeEvent(event)
+
+ def setDataRect(self, left, top, width, height):
+ """Set the bounds of the data rectangular area.
+
+ This sets the coordinate system.
+ """
+ self._dataRect.setRect(left, top, width, height)
+ self._visibleRect.setConstraintRect(left, top, width, height)
+ self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+
+ def setVisibleRect(self, left, top, width, height):
+ """Set the visible rectangular area.
+
+ The coordinates are relative to the data rect.
+ """
+ self.__visibleRect = left, top, width, height
+ self._visibleRect.setRect(0, 0, width, height)
+ self._visibleRect.setPos(left, top)
+ self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+
+ def __setVisibleRectFromPlot(self, plot):
+ """Update radar view visible area.
+
+ Takes care of y coordinate conversion.
+ """
+ xMin, xMax = plot.getXAxis().getLimits()
+ yMin, yMax = plot.getYAxis().getLimits()
+ self.setVisibleRect(xMin, yMin, xMax - xMin, yMax - yMin)
+
+ def getPlotWidget(self):
+ """Returns the connected plot
+
+ :rtype: Union[None,PlotWidget]
+ """
+ if self.__plotRef is None:
+ return None
+ plot = self.__plotRef()
+ if plot is None:
+ self.__plotRef = None
+ return plot
+
+ def setPlotWidget(self, plot):
+ """Set the PlotWidget this radar view connects to.
+
+ As result `setDataRect` and `setVisibleRect` will be called
+ automatically.
+
+ :param Union[None,PlotWidget] plot:
+ """
+ previousPlot = self.getPlotWidget()
+ if previousPlot is not None: # Disconnect previous plot
+ plot.getXAxis().sigLimitsChanged.disconnect(self._xLimitChanged)
+ plot.getYAxis().sigLimitsChanged.disconnect(self._yLimitChanged)
+ plot.getYAxis().sigInvertedChanged.disconnect(self._updateYAxisInverted)
+
+ # Reset plot and timer
+ # FIXME: It would be good to clean up the display here
+ self.__plotRef = None
+ self.__timer.stop()
+
+ if plot is not None: # Connect new plot
+ self.__plotRef = weakref.ref(plot)
+ plot.getXAxis().sigLimitsChanged.connect(self._xLimitChanged)
+ plot.getYAxis().sigLimitsChanged.connect(self._yLimitChanged)
+ plot.getYAxis().sigInvertedChanged.connect(self._updateYAxisInverted)
+ self.__setVisibleRectFromPlot(plot)
+ self._updateYAxisInverted()
+ self.__timer.start(500)
+
+ def _xLimitChanged(self, vmin, vmax):
+ plot = self.getPlotWidget()
+ self.__setVisibleRectFromPlot(plot)
+
+ def _yLimitChanged(self, vmin, vmax):
+ plot = self.getPlotWidget()
+ self.__setVisibleRectFromPlot(plot)
+
+ def _updateYAxisInverted(self, inverted=None):
+ """Sync radar view axis orientation."""
+ plot = self.getPlotWidget()
+ if inverted is None:
+ # Do not perform this when called from plot signal
+ inverted = plot.getYAxis().isInverted()
+ # Use scale to invert radarView
+ # RadarView default Y direction is from top to bottom
+ # As opposed to Plot. So invert RadarView when Plot is NOT inverted.
+ self.resetTransform()
+ if not inverted:
+ self.scale(1., -1.)
+ self.update()
+
+ def _viewRectDragged(self, left, top, width, height):
+ """Slot for radar view visible rectangle changes."""
+ plot = self.getPlotWidget()
+ if plot is None:
+ return
+
+ if self.__reentrant.locked():
+ return
+
+ with self.__reentrant:
+ plot.setLimits(left, left + width, top, top + height)
+
+ def _updateDataContent(self):
+ """Update the content to the current data content"""
+ plot = self.getPlotWidget()
+ if plot is None:
+ return
+ ranges = plot.getDataRange()
+ xmin, xmax = ranges.x if ranges.x is not None else (0, 0)
+ ymin, ymax = ranges.y if ranges.y is not None else (0, 0)
+ self.setDataRect(xmin, ymin, xmax - xmin, ymax - ymin)
+
+ self.__updateItem(self._imageRect, plot.getActiveImage())
+ self.__updateItem(self._scatterRect, plot.getActiveScatter())
+ self.__updateItem(self._curveRect, plot.getActiveCurve())
+
+ def __updateItem(self, rect, item):
+ """Sync rect with item bounds
+
+ :param QGraphicsRectItem rect:
+ :param Item item:
+ """
+ if item is None:
+ rect.setVisible(False)
+ return
+ ranges = item._getBounds()
+ if ranges is None:
+ rect.setVisible(False)
+ return
+ xmin, xmax, ymin, ymax = ranges
+ width = xmax - xmin
+ height = ymax - ymin
+ rect.setRect(xmin, ymin, width, height)
+ rect.setVisible(True)
diff --git a/silx/gui/plot/tools/profile/core.py b/silx/gui/plot/tools/profile/core.py
index 1f883dc..200f5cf 100644
--- a/silx/gui/plot/tools/profile/core.py
+++ b/silx/gui/plot/tools/profile/core.py
@@ -167,7 +167,10 @@ class ProfileRoiMixIn:
def __profileWindowAboutToClose(self):
profileManager = self.getProfileManager()
roiManager = profileManager.getRoiManager()
- roiManager.removeRoi(self)
+ try:
+ roiManager.removeRoi(self)
+ except ValueError:
+ pass
def computeProfile(self, item):
"""
diff --git a/silx/gui/plot/tools/profile/manager.py b/silx/gui/plot/tools/profile/manager.py
index 757b741..68db9a6 100644
--- a/silx/gui/plot/tools/profile/manager.py
+++ b/silx/gui/plot/tools/profile/manager.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -775,7 +775,8 @@ class ProfileManager(qt.QObject):
window = self._disconnectProfileWindow(profileRoi)
if window is not None:
geometry = window.geometry()
- self._previousWindowGeometry.append(geometry)
+ if not geometry.isEmpty():
+ self._previousWindowGeometry.append(geometry)
self.clearProfileWindow(window)
if profileRoi in self._rois:
self._rois.remove(profileRoi)
@@ -949,6 +950,7 @@ class ProfileManager(qt.QObject):
"""Handle item changes.
"""
if changeType in (items.ItemChangedType.DATA,
+ items.ItemChangedType.MASK,
items.ItemChangedType.POSITION,
items.ItemChangedType.SCALE):
self.requestUpdateAllProfile()
diff --git a/silx/gui/plot/tools/profile/rois.py b/silx/gui/plot/tools/profile/rois.py
index 9e651a7..eb7e975 100644
--- a/silx/gui/plot/tools/profile/rois.py
+++ b/silx/gui/plot/tools/profile/rois.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -33,7 +33,7 @@
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "03/04/2020"
+__date__ = "01/12/2020"
import numpy
import weakref
@@ -137,11 +137,7 @@ class _ImageProfileArea(items.Shape):
if not isinstance(item, items.ImageBase):
raise TypeError("Unexpected class %s" % type(item))
- if isinstance(item, items.ImageRgba):
- rgba = item.getData(copy=False)
- currentData = rgba[..., 0]
- else:
- currentData = item.getData(copy=False)
+ currentData = item.getValueData(copy=False)
roi = self.getParentRoi()
origin = item.getOrigin()
@@ -288,7 +284,7 @@ class _DefaultImageProfileRoiMixIn(core.ProfileRoiMixIn):
roiStart, roiEnd = self.getEndPoints()
else:
assert False
-
+
return roiStart, roiEnd, lineProjectionMode
def computeProfile(self, item):
@@ -310,15 +306,7 @@ class _DefaultImageProfileRoiMixIn(core.ProfileRoiMixIn):
method=method)
return coords, profile, profileName, xLabel
- if isinstance(item, items.ImageRgba):
- rgba = item.getData(copy=False)
- is_uint8 = rgba.dtype.type == numpy.uint8
- # luminosity
- if is_uint8:
- rgba = rgba.astype(numpy.float64)
- currentData = 0.21 * rgba[..., 0] + 0.72 * rgba[..., 1] + 0.07 * rgba[..., 2]
- else:
- currentData = item.getData(copy=False)
+ currentData = item.getValueData(copy=False)
yLabel = "%s" % str(method).capitalize()
coords, profile, title, xLabel = createProfile2(currentData)
@@ -427,7 +415,7 @@ class ProfileImageDirectedLineROI(roi_items.LineROI,
scale = item.getScale()
method = self.getProfileMethod()
lineWidth = self.getProfileLineWidth()
- currentData = item.getData(copy=False)
+ currentData = item.getValueData(copy=False)
roiInfo = self._getRoiInfo()
roiStart, roiEnd, _lineProjectionMode = roiInfo
@@ -448,8 +436,8 @@ class ProfileImageDirectedLineROI(roi_items.LineROI,
method=method)
# Compute the line size
- lineSize = numpy.sqrt((roiEnd[1] - roiStart[1])**2 +
- (roiEnd[0] - roiStart[0])**2)
+ lineSize = numpy.sqrt((roiEnd[1] - roiStart[1]) ** 2 +
+ (roiEnd[0] - roiStart[0]) ** 2)
coords = numpy.linspace(0, lineSize, len(profile),
endpoint=True,
dtype=numpy.float32)
diff --git a/silx/gui/plot3d/items/__init__.py b/silx/gui/plot3d/items/__init__.py
index 5810618..e7c4af1 100644
--- a/silx/gui/plot3d/items/__init__.py
+++ b/silx/gui/plot3d/items/__init__.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -37,7 +37,7 @@ from .core import ItemChangedType, Item3DChangedType # noqa
from .mixins import (ColormapMixIn, ComplexMixIn, InterpolationMixIn, # noqa
PlaneMixIn, SymbolMixIn) # noqa
from .clipplane import ClipPlane # noqa
-from .image import ImageData, ImageRgba # noqa
+from .image import ImageData, ImageRgba, HeightMapData, HeightMapRGBA # noqa
from .mesh import Mesh, ColormapMesh, Box, Cylinder, Hexagon # noqa
from .scatter import Scatter2D, Scatter3D # noqa
from .volume import ComplexField3D, ScalarField3D # noqa
diff --git a/silx/gui/plot3d/items/image.py b/silx/gui/plot3d/items/image.py
index cfd1188..4e2b396 100644
--- a/silx/gui/plot3d/items/image.py
+++ b/silx/gui/plot3d/items/image.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -174,3 +174,252 @@ class ImageRgba(_Image, InterpolationMixIn):
:return: The image data
"""
return self._image.getData(copy=copy)
+
+
+class _HeightMap(DataItem3D):
+ """Base class for 2D data array displayed as a height field.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ DataItem3D.__init__(self, parent=parent)
+ self.__data = numpy.zeros((0, 0), dtype=numpy.float32)
+
+ def _pickFull(self, context, threshold=0., sort='depth'):
+ """Perform picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :param float threshold: Picking threshold in pixel.
+ Perform picking in a square of size threshold x threshold.
+ :param str sort: How returned indices are sorted:
+
+ - 'index' (default): sort by the value of the indices
+ - 'depth': Sort by the depth of the points from the current
+ camera point of view.
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ assert sort in ('index', 'depth')
+
+ rayNdc = context.getPickingSegment(frame='ndc')
+ if rayNdc is None: # No picking outside viewport
+ return None
+
+ # TODO no colormapped or color data
+ # Project data to NDC
+ heightData = self.getData(copy=False)
+ if heightData.size == 0:
+ return # Nothing displayed
+
+ height, width = heightData.shape
+ z = numpy.ravel(heightData)
+ y, x = numpy.mgrid[0:height, 0:width]
+ dataPoints = numpy.transpose((numpy.ravel(x),
+ numpy.ravel(y),
+ z,
+ numpy.ones_like(z)))
+
+ primitive = self._getScenePrimitive()
+
+ pointsNdc = primitive.objectToNDCTransform.transformPoints(
+ dataPoints, perspectiveDivide=True)
+
+ # Perform picking
+ distancesNdc = numpy.abs(pointsNdc[:, :2] - rayNdc[0, :2])
+ # TODO issue with symbol size: using pixel instead of points
+ threshold += 1. # symbol size
+ thresholdNdc = 2. * threshold / numpy.array(primitive.viewport.size)
+ picked = numpy.where(numpy.logical_and(
+ numpy.all(distancesNdc < thresholdNdc, axis=1),
+ numpy.logical_and(rayNdc[0, 2] <= pointsNdc[:, 2],
+ pointsNdc[:, 2] <= rayNdc[1, 2])))[0]
+
+ if sort == 'depth':
+ # Sort picked points from front to back
+ picked = picked[numpy.argsort(pointsNdc[picked, 2])]
+
+ if picked.size > 0:
+ # Convert indices from 1D to 2D
+ return PickingResult(self,
+ positions=dataPoints[picked, :3],
+ indices=(picked // width, picked % width),
+ fetchdata=self.getData)
+ else:
+ return None
+
+ def setData(self, data, copy: bool=True):
+ """Set the height field data.
+
+ :param data:
+ :param copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+
+ self.__data = data
+ self._updated(ItemChangedType.DATA)
+
+ def getData(self, copy: bool=True) -> numpy.ndarray:
+ """Get the height field 2D data.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ """
+ return numpy.array(self.__data, copy=copy)
+
+
+class HeightMapData(_HeightMap, ColormapMixIn):
+ """Description of a 2D height field associated to a colormapped dataset.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _HeightMap.__init__(self, parent=parent)
+ ColormapMixIn.__init__(self)
+
+ self.__data = numpy.zeros((0, 0), dtype=numpy.float32)
+
+ def _updated(self, event=None):
+ if event == ItemChangedType.DATA:
+ self.__updateScene()
+ super()._updated(event=event)
+
+ def __updateScene(self):
+ """Update display primitive to use"""
+ self._getScenePrimitive().children = [] # Remove previous primitives
+ ColormapMixIn._setSceneColormap(self, None)
+
+ if not self.isVisible():
+ return # Update when visible
+
+ data = self.getColormappedData(copy=False)
+ heightData = self.getData(copy=False)
+
+ if data.size == 0 or heightData.size == 0:
+ return # Nothing to display
+
+ # Display as a set of points
+ height, width = heightData.shape
+ # Generates coordinates
+ y, x = numpy.mgrid[0:height, 0:width]
+
+ if data.shape != heightData.shape: # data and height size miss-match
+ # Colormapped data is interpolated (nearest-neighbour) to match the height field
+ data = data[numpy.floor(y * data.shape[0] / height).astype(numpy.int),
+ numpy.floor(x * data.shape[1] / height).astype(numpy.int)]
+
+ x = numpy.ravel(x)
+ y = numpy.ravel(y)
+
+ primitive = primitives.Points(
+ x=x,
+ y=y,
+ z=numpy.ravel(heightData),
+ value=numpy.ravel(data),
+ size=1)
+ primitive.marker = 's'
+ ColormapMixIn._setSceneColormap(self, primitive.colormap)
+ self._getScenePrimitive().children = [primitive]
+
+ def setColormappedData(self, data, copy: bool=True):
+ """Set the 2D data used to compute colors.
+
+ :param data: 2D array of data
+ :param copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+
+ self.__data = data
+ self._updated(ItemChangedType.DATA)
+
+ def getColormappedData(self, copy: bool=True) -> numpy.ndarray:
+ """Returns the 2D data used to compute colors.
+
+ :param copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ """
+ return numpy.array(self.__data, copy=copy)
+
+
+class HeightMapRGBA(_HeightMap):
+ """Description of a 2D height field associated to a RGB(A) image.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _HeightMap.__init__(self, parent=parent)
+
+ self.__rgba = numpy.zeros((0, 0, 3), dtype=numpy.float32)
+
+ def _updated(self, event=None):
+ if event == ItemChangedType.DATA:
+ self.__updateScene()
+ super()._updated(event=event)
+
+ def __updateScene(self):
+ """Update display primitive to use"""
+ self._getScenePrimitive().children = [] # Remove previous primitives
+
+ if not self.isVisible():
+ return # Update when visible
+
+ rgba = self.getColorData(copy=False)
+ heightData = self.getData(copy=False)
+ if rgba.size == 0 or heightData.size == 0:
+ return # Nothing to display
+
+ # Display as a set of points
+ height, width = heightData.shape
+ # Generates coordinates
+ y, x = numpy.mgrid[0:height, 0:width]
+
+ if rgba.shape[:2] != heightData.shape: # image and height size miss-match
+ # RGBA data is interpolated (nearest-neighbour) to match the height field
+ rgba = rgba[numpy.floor(y * rgba.shape[0] / height).astype(numpy.int),
+ numpy.floor(x * rgba.shape[1] / height).astype(numpy.int)]
+
+ x = numpy.ravel(x)
+ y = numpy.ravel(y)
+
+ primitive = primitives.ColorPoints(
+ x=x,
+ y=y,
+ z=numpy.ravel(heightData),
+ color=rgba.reshape(-1, rgba.shape[-1]),
+ size=1)
+ primitive.marker = 's'
+ self._getScenePrimitive().children = [primitive]
+
+ def setColorData(self, data, copy: bool=True):
+ """Set the RGB(A) image to use.
+
+ Supported array format: float32 in [0, 1], uint8.
+
+ :param data:
+ The RGBA image data as an array of shape (H, W, Channels)
+ :param copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 3
+ assert data.shape[-1] in (3, 4)
+ # TODO check type
+
+ self.__rgba = data
+ self._updated(ItemChangedType.DATA)
+
+ def getColorData(self, copy: bool=True) -> numpy.ndarray:
+ """Get the RGB(A) image data.
+
+ :param copy: True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ """
+ return numpy.array(self.__rgba, copy=copy)
diff --git a/silx/gui/plot3d/scene/primitives.py b/silx/gui/plot3d/scene/primitives.py
index b4c8e26..7f35c3c 100644
--- a/silx/gui/plot3d/scene/primitives.py
+++ b/silx/gui/plot3d/scene/primitives.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -1383,7 +1383,7 @@ class ColorPoints(_Points):
"""A set of points with an associated color and size."""
_ATTR_INFO = _Points._ATTR_INFO.copy()
- _ATTR_INFO.update({'value': {'dims': (1, 2), 'lastDim': (4,)}})
+ _ATTR_INFO.update({'value': {'dims': (1, 2), 'lastDim': (3, 4)}})
def __init__(self, x, y, z, color=(1., 1., 1., 1.), size=1.,
indices=None):
@@ -2041,7 +2041,10 @@ class _Image(Geometry):
$scenePreCall(vCameraPosition);
vec4 color = imageColor(data, vTexCoords);
- color.a = alpha;
+ color.a *= alpha;
+ if (color.a == 0.) { /* Discard fully transparent pixels */
+ discard;
+ }
vec3 normal = vec3(0.0, 0.0, 1.0);
gl_FragColor = $lightingCall(color, vPosition, normal);
diff --git a/silx/gui/plot3d/test/testSceneWindow.py b/silx/gui/plot3d/test/testSceneWindow.py
index b2e6ea0..8cf6b81 100644
--- a/silx/gui/plot3d/test/testSceneWindow.py
+++ b/silx/gui/plot3d/test/testSceneWindow.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2019 European Synchrotron Radiation Facility
+# Copyright (c) 2019-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -37,7 +37,7 @@ from silx.gui.utils.testutils import TestCaseQt
from silx.gui import qt
from silx.gui.plot3d.SceneWindow import SceneWindow
-
+from silx.gui.plot3d.items import HeightMapData, HeightMapRGBA
class TestSceneWindow(TestCaseQt, ParametricTestCase):
"""Tests SceneWidget picking feature"""
@@ -114,6 +114,42 @@ class TestSceneWindow(TestCaseQt, ParametricTestCase):
sceneWidget.resetZoom('front')
self.qapp.processEvents()
+ def testHeightMap(self):
+ """Test height map items"""
+ sceneWidget = self.window.getSceneWidget()
+
+ height = numpy.arange(10000).reshape(100, 100) /100.
+
+ for shape in ((100, 100), (4, 5), (150, 20), (110, 110)):
+ with self.subTest(shape=shape):
+ items = []
+
+ # Colormapped data height map
+ data = numpy.arange(numpy.prod(shape)).astype(numpy.float32).reshape(shape)
+
+ heightmap = HeightMapData()
+ heightmap.setData(height)
+ heightmap.setColormappedData(data)
+ heightmap.getColormap().setName('viridis')
+ items.append(heightmap)
+ sceneWidget.addItem(heightmap)
+
+ # RGBA height map
+ colors = numpy.zeros(shape + (3,), dtype=numpy.float32)
+ colors[:, :, 1] = numpy.random.random(shape)
+
+ heightmap = HeightMapRGBA()
+ heightmap.setData(height)
+ heightmap.setColorData(colors)
+ heightmap.setTranslation(100., 0., 0.)
+ items.append(heightmap)
+ sceneWidget.addItem(heightmap)
+
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+ sceneWidget.resetZoom('front')
+ self.qapp.processEvents()
+ sceneWidget.clearItems()
+
def testChangeContent(self):
"""Test add/remove/clear items"""
sceneWidget = self.window.getSceneWidget()
diff --git a/silx/gui/plot3d/tools/PositionInfoWidget.py b/silx/gui/plot3d/tools/PositionInfoWidget.py
index fc86a7f..78f2959 100644
--- a/silx/gui/plot3d/tools/PositionInfoWidget.py
+++ b/silx/gui/plot3d/tools/PositionInfoWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -144,6 +144,8 @@ class PositionInfoWidget(qt.QWidget):
items.Scatter2D,
items.ImageData,
items.ImageRgba,
+ items.HeightMapData,
+ items.HeightMapRGBA,
items.Mesh,
items.Box,
items.Cylinder,
diff --git a/silx/gui/qt/__init__.py b/silx/gui/qt/__init__.py
index b8c6cdd..ace2841 100644
--- a/silx/gui/qt/__init__.py
+++ b/silx/gui/qt/__init__.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -46,7 +46,7 @@ Example of using :mod:`silx.gui.qt` module:
>>> widget = qt.QWidget()
For an alternative solution providing a structured namespace,
-see `qtpy <https://pypi.python.org/pypi/QtPy/>`_ which
+see `qtpy <https://pypi.org/project/QtPy/>`_ which
provides the namespace of PyQt5 over PyQt4, PySide and PySide2.
"""
diff --git a/silx/gui/setup.py b/silx/gui/setup.py
index 6eb87ae..04a2bac 100644
--- a/silx/gui/setup.py
+++ b/silx/gui/setup.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -43,6 +43,7 @@ def configuration(parent_package='', top_path=None):
config.add_subpackage('data')
config.add_subpackage('dialog')
config.add_subpackage('utils')
+ config.add_subpackage('utils.glutils')
config.add_subpackage('utils.test')
return config
diff --git a/silx/gui/utils/glutils.py b/silx/gui/utils/glutils/__init__.py
index 83cfd89..c90f029 100644
--- a/silx/gui/utils/glutils.py
+++ b/silx/gui/utils/glutils/__init__.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2020 European Synchrotron Radiation Facility
+# Copyright (c) 2020-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,13 +27,6 @@
import os
import sys
-
-if __name__ == "__main__":
- # When run as a script, remove directory from sys.path
- # This avoids other script in same directory to override Python modules
- if os.path.abspath(sys.path[0]) == os.path.abspath(os.path.dirname(__file__)):
- sys.path.pop(0)
-
import subprocess
from silx.gui import qt
@@ -77,7 +70,7 @@ def _runtimeOpenGLCheck(version):
try:
error = subprocess.check_output(
- [sys.executable, __file__, major, minor],
+ [sys.executable, '-s', '-S', __file__, major, minor],
env=env,
timeout=2)
except subprocess.TimeoutExpired:
@@ -190,7 +183,7 @@ if __name__ == "__main__":
app = qt.QApplication([])
window = qt.QMainWindow(flags=
- qt.Qt.Window |
+ qt.Qt.Popup |
qt.Qt.FramelessWindowHint |
qt.Qt.NoDropShadowWindowHint |
qt.Qt.WindowStaysOnTopHint)
diff --git a/silx/gui/widgets/FlowLayout.py b/silx/gui/widgets/FlowLayout.py
index 14c8ab2..3c4c9dd 100644
--- a/silx/gui/widgets/FlowLayout.py
+++ b/silx/gui/widgets/FlowLayout.py
@@ -71,7 +71,7 @@ class FlowLayout(qt.QLayout):
return None
def expandingDirections(self):
- return 0
+ return qt.Qt.Orientations()
def hasHeightForWidth(self):
return True
diff --git a/silx/gui/widgets/UrlSelectionTable.py b/silx/gui/widgets/UrlSelectionTable.py
index 27ea363..fb15edd 100644
--- a/silx/gui/widgets/UrlSelectionTable.py
+++ b/silx/gui/widgets/UrlSelectionTable.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2017 European Synchrotron Radiation Facility
+# Copyright (C) 2017-2021 European Synchrotron Radiation Facility
#
# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
# the ESRF by the Software group.
@@ -37,7 +37,7 @@ import functools
import logging
import os
-logger = logging.getLogger(__file__)
+logger = logging.getLogger(__name__)
class UrlSelectionTable(TableWidget):
diff --git a/silx/image/bilinear.pyx b/silx/image/bilinear.pyx
index 7f6354b..14547f8 100644
--- a/silx/image/bilinear.pyx
+++ b/silx/image/bilinear.pyx
@@ -1,9 +1,16 @@
# -*- coding: utf-8 -*-
+#cython: embedsignature=True, language_level=3
+## This is for optimisation
+#cython: boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False,
+## This is for developping:
+##cython: profile=True, warn.undeclared=True, warn.unused=True, warn.unused_result=False, warn.unused_arg=True
+#
+#
#
# Project: silx (originally pyFAI)
# https://github.com/silx-kit/silx
#
-# Copyright (C) 2012-2017 European Synchrotron Radiation Facility, Grenoble, France
+# Copyright (C) 2012-2020 European Synchrotron Radiation Facility, Grenoble, France
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,32 +30,49 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
+"""Bilinear interpolator, peak finder, line-profile for images"""
__authors__ = ["J. Kieffer"]
__license__ = "MIT"
-__date__ = "15/09/2016"
-__doc__ = "Bilinear interpolator, peak finder, line-profile for images"
+__date__ = "26/11/2020"
+
+# C-level imports
+from libc.stdint cimport uint8_t
+from libc.math cimport floor, ceil, sqrt, NAN, isfinite
+from libc.float cimport FLT_MAX
import cython
-from cython.view cimport array as cvarray
import numpy
-from libc.math cimport floor, ceil, sin, cos, sqrt, atan2
import logging
logger = logging.getLogger(__name__)
+#Definition of some constants
+# How data are stored
+ctypedef float data_t
+data_d = numpy.float32
+
+#How the mask is stored
+ctypedef uint8_t mask_t
+mask_d = numpy.uint8
+
+
cdef class BilinearImage:
"""Bilinear interpolator for images ... or any data on a regular grid
"""
cdef:
- readonly float[:, ::1] data
- readonly float maxi, mini
- readonly size_t width, height
-
- cpdef size_t coarse_local_maxi(self, size_t)
- cdef size_t c_local_maxi(self, size_t) nogil
- cdef float c_funct(self, float, float) nogil
-
- def __cinit__(self, data not None):
+ readonly data_t[:, ::1] data
+ readonly mask_t[:, ::1] mask
+ readonly data_t maxi, mini
+ readonly Py_ssize_t width, height
+ readonly bint has_mask
+
+ # C-level declarations
+ cpdef Py_ssize_t coarse_local_maxi(self, Py_ssize_t)
+ cdef Py_ssize_t c_local_maxi(self, Py_ssize_t) nogil
+ cdef data_t c_funct(self, data_t, data_t) nogil
+ cdef void _init_min_max(self) nogil
+
+ def __cinit__(self, data not None, mask=None):
"""Constructor
:param data: image as a 2D array
@@ -56,12 +80,18 @@ cdef class BilinearImage:
assert data.ndim == 2
self.height = data.shape[0]
self.width = data.shape[1]
- self.maxi = data.max()
- self.mini = data.min()
- self.data = numpy.ascontiguousarray(data, dtype=numpy.float32)
+ self.data = numpy.ascontiguousarray(data, dtype=data_d)
+ if mask is not None:
+ self.mask = numpy.ascontiguousarray(mask, dtype=mask_d)
+ self.has_mask=True
+ else:
+ self.mask = None
+ self.has_mask = False
+ self._init_min_max()
def __dealloc__(self):
self.data = None
+ self.mask = None
def __call__(self, coord):
"""Function f((y, x)) where f is a continuous function
@@ -72,10 +102,24 @@ cdef class BilinearImage:
:return: Interpolated signal from the image
"""
return self.c_funct(coord[1], coord[0])
-
- @cython.boundscheck(False)
- @cython.wraparound(False)
- cdef float c_funct(self, float x, float y) nogil:
+
+ cdef void _init_min_max(self) nogil:
+ "Calculate the min & max"
+ cdef:
+ Py_ssize_t i, j
+ data_t maxi, mini, value
+ mini = FLT_MAX
+ maxi = -FLT_MAX
+ for i in range(self.height):
+ for j in range(self.width):
+ if not (self.has_mask and self.mask[i,j]):
+ value = self.data[i, j]
+ maxi = max(value, maxi)
+ mini = min(value, mini)
+ self.maxi = maxi
+ self.mini = mini
+
+ cdef data_t c_funct(self, data_t x, data_t y) nogil:
"""Function f(x, y) where f is a continuous function
made from the image.
@@ -86,10 +130,11 @@ cdef class BilinearImage:
Cython only function due to NOGIL
"""
cdef:
- float d0 = min(max(y, 0.0), (self.height - 1.0))
- float d1 = min(max(x, 0.0), (self.width - 1.0))
- int i0, i1, j0, j1
- float x0, x1, y0, y1, res
+ data_t d0 = min(max(y, 0.0), (self.height - 1.0))
+ data_t d1 = min(max(x, 0.0), (self.width - 1.0))
+ mask_t m0, m1, m2, m3
+ Py_ssize_t i0, i1, j0, j1
+ data_t x0, x1, y0, y1, res, scale
x0 = floor(d0)
x1 = ceil(d0)
@@ -100,20 +145,74 @@ cdef class BilinearImage:
j0 = < int > y0
j1 = < int > y1
if (i0 == i1) and (j0 == j1):
- res = self.data[i0, j0]
+ if not (self.has_mask and self.mask[i0,j0]):
+ res = self.data[i0, j0]
+ else:
+ res = NAN
elif i0 == i1:
- res = (self.data[i0, j0] * (y1 - d1)) + (self.data[i0, j1] * (d1 - y0))
+ if self.has_mask:
+ m0 = self.mask[i0, j0]
+ m1 = self.mask[i0, j1]
+ if m0 and m1:
+ res = NAN
+ elif m0:
+ res = self.data[i0, j1]
+ elif m1:
+ res = self.data[i0, j0]
+ else:
+ res = (self.data[i0, j0] * (y1 - d1)) + (self.data[i0, j1] * (d1 - y0))
+ else:
+ res = (self.data[i0, j0] * (y1 - d1)) + (self.data[i0, j1] * (d1 - y0))
elif j0 == j1:
- res = (self.data[i0, j0] * (x1 - d0)) + (self.data[i1, j0] * (d0 - x0))
+ if self.has_mask:
+ m0 = self.mask[i0, j0]
+ m1 = self.mask[i1, j0]
+ if m0 and m1:
+ res = NAN
+ elif m0:
+ res = self.data[i1, j0]
+ elif m1:
+ res = self.data[i0, j0]
+ else:
+ res = (self.data[i0, j0] * (x1 - d0)) + (self.data[i1, j0] * (d0 - x0))
+ else:
+ res = (self.data[i0, j0] * (x1 - d0)) + (self.data[i1, j0] * (d0 - x0))
else:
- res = (self.data[i0, j0] * (x1 - d0) * (y1 - d1)) \
- + (self.data[i1, j0] * (d0 - x0) * (y1 - d1)) \
- + (self.data[i0, j1] * (x1 - d0) * (d1 - y0)) \
- + (self.data[i1, j1] * (d0 - x0) * (d1 - y0))
+ if self.has_mask:
+ m0 = self.mask[i0, j0]
+ m1 = self.mask[i1, j0]
+ m2 = self.mask[i0, j1]
+ m3 = self.mask[i1, j1]
+ if m0 and m1 and m2 and m3:
+ res = NAN
+ else:
+ m0 = not m0
+ m1 = not m1
+ m2 = not m2
+ m3 = not m3
+ if m0 and m1 and m2 and m3:
+ res = (self.data[i0, j0] * (x1 - d0) * (y1 - d1)) \
+ + (self.data[i1, j0] * (d0 - x0) * (y1 - d1)) \
+ + (self.data[i0, j1] * (x1 - d0) * (d1 - y0)) \
+ + (self.data[i1, j1] * (d0 - x0) * (d1 - y0))
+ else:
+ res = (m0 * self.data[i0, j0] * (x1 - d0) * (y1 - d1)) \
+ + (m1 * self.data[i1, j0] * (d0 - x0) * (y1 - d1)) \
+ + (m2 * self.data[i0, j1] * (x1 - d0) * (d1 - y0)) \
+ + (m3 * self.data[i1, j1] * (d0 - x0) * (d1 - y0))
+ scale = ((m0 * (x1 - d0) * (y1 - d1))
+ + (m1 * (d0 - x0) * (y1 - d1))
+ + (m2 * (x1 - d0) * (d1 - y0))
+ + (m3 * (d0 - x0) * (d1 - y0)))
+ res /= scale
+ else:
+ res = (self.data[i0, j0] * (x1 - d0) * (y1 - d1)) \
+ + (self.data[i1, j0] * (d0 - x0) * (y1 - d1)) \
+ + (self.data[i0, j1] * (x1 - d0) * (d1 - y0)) \
+ + (self.data[i1, j1] * (d0 - x0) * (d1 - y0))
+
return res
- @cython.boundscheck(False)
- @cython.wraparound(False)
def opp_f(self, coord):
"""Function -f((y,x)) for peak finding via minimizer.
@@ -123,7 +222,7 @@ cdef class BilinearImage:
:return: Negative interpolated signal from the image
"""
cdef:
- float d0, d1, res
+ data_t d0, d1, res
d0, d1 = coord
if d0 < 0:
res = self.mini + d0
@@ -137,9 +236,6 @@ cdef class BilinearImage:
res = self.c_funct(d1, d0)
return - res
- @cython.boundscheck(False)
- @cython.wraparound(False)
- @cython.cdivision(True)
def local_maxi(self, coord):
"""Return the nearest local maximum ... with sub-pixel refinement
@@ -158,9 +254,9 @@ cdef class BilinearImage:
cdef:
int res, current0, current1
int i0, i1
- float tmp, sum0 = 0, sum1 = 0, sum = 0
- float a00, a01, a02, a10, a11, a12, a20, a21, a22
- float d00, d11, d01, denom, delta0, delta1
+ data_t tmp, sum0 = 0, sum1 = 0, sum = 0
+ data_t a00, a01, a02, a10, a11, a12, a20, a21, a22
+ data_t d00, d11, d01, denom, delta0, delta1
res = self.c_local_maxi(round(coord[0]) * self.width + round(coord[1]))
current0 = res // self.width
@@ -202,7 +298,7 @@ cdef class BilinearImage:
return (float(current0), float(current1))
- cpdef size_t coarse_local_maxi(self, size_t x):
+ cpdef Py_ssize_t coarse_local_maxi(self, Py_ssize_t x):
"""Return the nearest local maximum ... without sub-pixel refinement
:param idx: start index (=row*width+column)
@@ -210,10 +306,7 @@ cdef class BilinearImage:
"""
return self.c_local_maxi(x)
- @cython.boundscheck(False)
- @cython.wraparound(False)
- @cython.cdivision(True)
- cdef size_t c_local_maxi(self, size_t idx) nogil:
+ cdef Py_ssize_t c_local_maxi(self, Py_ssize_t idx) nogil:
"""Return the nearest local maximum without sub-pixel refinement
:param idx: start index (=row*width+column)
@@ -222,13 +315,39 @@ cdef class BilinearImage:
This method is Cython only due to the NOGIL
"""
cdef:
- int current0 = idx // self.width
- int current1 = idx % self.width
- int i0, i1, start0, stop0, start1, stop1, new0, new1
- float tmp, value, old_value
-
- value = self.data[current0, current1]
- old_value = value - 1.0
+ Py_ssize_t current0 = idx // self.width
+ Py_ssize_t current1 = idx % self.width
+ Py_ssize_t i0, i1, start0, stop0, start1, stop1, new0, new1, rng, cnt
+ mask_t m
+ data_t tmp, value, old_value
+
+ if self.has_mask and self.mask[current0, current1]:
+ #Start searching for a non masked pixel.
+ rng = 0
+ cnt = 0
+ value = self.mini
+ new0, new1 = current0, current1
+ while cnt == 0:
+ rng += 1
+ cnt = 0
+ start0 = max(0, current0 - rng)
+ stop0 = min(self.height, current0 + rng + 1)
+ start1 = max(0, current1 - rng)
+ stop1 = min(self.width, current1 + rng + 1)
+ for i0 in range(start0, stop0):
+ for i1 in range(start1, stop1):
+ m = not self.mask[i0, i1]
+ cnt += m
+ if m:
+ tmp = self.data[i0, i1]
+ if tmp > value:
+ new0, new1 = i0, i1
+ value = tmp
+ current0, current1 = new0, new1
+ else:
+ value = self.data[current0, current1]
+
+ old_value = value -1
new0, new1 = current0, current1
while value > old_value:
@@ -239,6 +358,8 @@ cdef class BilinearImage:
stop1 = min(self.width, current1 + 2)
for i0 in range(start0, stop0):
for i1 in range(start1, stop1):
+ if self.has_mask and self.mask[current0, current1]:
+ continue
tmp = self.data[i0, i1]
if tmp > value:
new0, new1 = i0, i1
@@ -246,7 +367,6 @@ cdef class BilinearImage:
current0, current1 = new0, new1
return self.width * current0 + current1
- @cython.boundscheck(False)
def map_coordinates(self, coordinates):
"""Map coordinates of the array on the image
@@ -254,20 +374,19 @@ cdef class BilinearImage:
:return: array of values at given coordinates
"""
cdef:
- float[:] d0, d1, res
- size_t size, i
+ data_t[:] d0, d1, res
+ Py_ssize_t size, i
shape = coordinates[0].shape
size = coordinates[0].size
- d0 = numpy.ascontiguousarray(coordinates[0].ravel(), dtype=numpy.float32)
- d1 = numpy.ascontiguousarray(coordinates[1].ravel(), dtype=numpy.float32)
+ d0 = numpy.ascontiguousarray(coordinates[0].ravel(), dtype=data_d)
+ d1 = numpy.ascontiguousarray(coordinates[1].ravel(), dtype=data_d)
assert size == d1.size
- res = numpy.empty(size, dtype=numpy.float32)
+ res = numpy.empty(size, dtype=data_d)
with nogil:
for i in range(size):
res[i] = self.c_funct(d1[i], d0[i])
return numpy.asarray(res).reshape(shape)
- @cython.boundscheck(False)
def profile_line(self, src, dst, int linewidth=1, method='mean'):
"""Return the mean or sum of intensity profile of an image measured
along a scan line.
@@ -289,10 +408,11 @@ cdef class BilinearImage:
Inspired from skimage
"""
cdef:
- float src_row, src_col, dst_row, dst_col, d_row, d_col
- float length, col_width, row_width, sum, row, col, new_row, new_col
- int lengt, i, j, cnt
- float[::1] result
+ data_t src_row, src_col, dst_row, dst_col, d_row, d_col
+ data_t length, col_width, row_width, sum, row, col, new_row, new_col, val
+ Py_ssize_t lengt, i, j, cnt
+ bint compute_mean
+ data_t[::1] result
src_row, src_col = src
dst_row, dst_col = dst
if (src_row == dst_row) and (src_col == dst_col):
@@ -307,10 +427,10 @@ cdef class BilinearImage:
col_width = - d_row / length
lengt = <int> ceil(length + 1)
- d_row /= <float> (lengt -1)
- d_col /= <float> (lengt -1)
+ d_row /= <data_t> (lengt -1)
+ d_col /= <data_t> (lengt -1)
- result = numpy.zeros(lengt, dtype=numpy.float32)
+ result = numpy.zeros(lengt, dtype=data_d)
# Offset position to the center of the bottom pixels of the profile
src_row -= row_width * (linewidth - 1) / 2.
@@ -330,13 +450,16 @@ cdef class BilinearImage:
new_col = col + j * col_width
if ((new_col >= 0) and (new_col < self.width) and
(new_row >= 0) and (new_row < self.height)):
- cnt = cnt + 1
- sum = sum + self.c_funct(new_col, new_row)
+ val = self.c_funct(new_col, new_row)
+ if isfinite(val):
+ cnt += 1
+ sum += val
if cnt:
- if compute_mean is True:
- result[i] += sum / cnt
- else:
- result[i] += sum
-
+ if compute_mean:
+ result[i] += sum / cnt
+ else:
+ result[i] += sum
+ elif compute_mean:
+ result[i] += NAN
# Ensures the result is exported as numpy array and not memory view.
return numpy.asarray(result)
diff --git a/silx/image/test/test_bilinear.py b/silx/image/test/test_bilinear.py
index 12d0067..55eaccb 100644
--- a/silx/image/test/test_bilinear.py
+++ b/silx/image/test/test_bilinear.py
@@ -24,7 +24,7 @@
__authors__ = ["J. Kieffer"]
__license__ = "MIT"
-__date__ = "02/08/2016"
+__date__ = "25/11/2020"
import unittest
import numpy
@@ -45,6 +45,10 @@ class TestBilinear(unittest.TestCase):
gb = numpy.exp(-b * b / 6000)
gg = numpy.outer(ga, gb)
b = BilinearImage(gg)
+
+ self.assertAlmostEqual(b.maxi, 1, 2, "maxi is almost 1")
+ self.assertLess(b.mini, 0.3, "mini should be around 0.23")
+
ok = 0
for s in range(self.N):
i, j = numpy.random.randint(100), numpy.random.randint(100)
@@ -78,8 +82,8 @@ class TestBilinear(unittest.TestCase):
self.assertEqual(ok, self.N, "Maximum is always found")
def test_map(self):
- N = 100
- y, x = numpy.ogrid[:N, :N + 10]
+ N = 6
+ y, x = numpy.ogrid[:N,:N + 10]
img = x + y
b = BilinearImage(img)
x2d = numpy.zeros_like(y) + x
@@ -87,16 +91,47 @@ class TestBilinear(unittest.TestCase):
res1 = b.map_coordinates((y2d, x2d))
self.assertEqual(abs(res1 - img).max(), 0, "images are the same (corners)")
- x2d = numpy.zeros_like(y) + (x[:, :-1] + 0.5)
- y2d = numpy.zeros_like(x[:, :-1]) + y
+ x2d = numpy.zeros_like(y) + (x[:,:-1] + 0.5)
+ y2d = numpy.zeros_like(x[:,:-1]) + y
res1 = b.map_coordinates((y2d, x2d))
- self.assertEqual(abs(res1 - img[:, :-1] - 0.5).max(), 0, "images are the same (middle)")
+ self.assertEqual(abs(res1 - img[:,:-1] - 0.5).max(), 0, "images are the same (middle)")
- x2d = numpy.zeros_like(y[:-1, :]) + (x[:, :-1] + 0.5)
- y2d = numpy.zeros_like(x[:, :-1]) + (y[:-1, :] + 0.5)
+ x2d = numpy.zeros_like(y[:-1,:]) + (x[:,:-1] + 0.5)
+ y2d = numpy.zeros_like(x[:,:-1]) + (y[:-1,:] + 0.5)
res1 = b.map_coordinates((y2d, x2d))
self.assertEqual(abs(res1 - img[:-1, 1:]).max(), 0, "images are the same (center)")
+ def test_mask_grad(self):
+ N = 100
+ img = numpy.arange(N * N).reshape(N, N)
+ # No mask on the boundaries, makes the test complicated, pixel always separated
+ masked = 2 * numpy.random.randint(0, int((N - 1) / 2), size=(2, N)) + 1
+ mask = numpy.zeros((N, N), dtype=numpy.uint8)
+ mask[(masked[0], masked[1])] = 1
+ self.assertLessEqual(mask.sum(), N, "At most N pixels are masked")
+
+ b = BilinearImage(img, mask=mask)
+ self.assertEqual(b.has_mask, True, "interpolator has mask")
+ self.assertEqual(b.maxi, N * N - 1, "maxi is N²-1")
+ self.assertEqual(b.mini, 0, "mini is 0")
+
+ y, x = numpy.ogrid[:N,:N]
+ x2d = numpy.zeros_like(y) + x
+ y2d = numpy.zeros_like(x) + y
+ res1 = b.map_coordinates((y2d, x2d))
+ self.assertEqual(numpy.nanmax(abs(res1 - img)), 0, "images are the same (corners), or Nan ")
+
+ x2d = numpy.zeros_like(y) + (x[:,:-1] + 0.5)
+ y2d = numpy.zeros_like(x[:,:-1]) + y
+ res1 = b.map_coordinates((y2d, x2d))
+ self.assertLessEqual(numpy.max(abs(res1 - img[:, 1:] + 1 / 2.)), 0.5, "images are the same (middle) +/- 0.5")
+
+ x2d = numpy.zeros_like(y[:-1]) + (x[:,:-1] + 0.5)
+ y2d = numpy.zeros_like(x[:,:-1]) + (y[:-1] + 0.5)
+ res1 = b.map_coordinates((y2d, x2d))
+ exp = 0.25 * (img[:-1,:-1] + img[:-1, 1:] + img[1:,:-1] + img[1:, 1:])
+ self.assertLessEqual(abs(res1 - exp).max(), N / 4, "images are almost the same (center)")
+
def test_profile_grad(self):
N = 100
img = numpy.arange(N * N).reshape(N, N)
@@ -139,4 +174,5 @@ def suite():
testsuite.addTest(TestBilinear("test_map"))
testsuite.addTest(TestBilinear("test_profile_grad"))
testsuite.addTest(TestBilinear("test_profile_gaus"))
+ testsuite.addTest(TestBilinear("test_mask_grad"))
return testsuite
diff --git a/silx/io/dictdump.py b/silx/io/dictdump.py
index bbb244a..e907668 100644
--- a/silx/io/dictdump.py
+++ b/silx/io/dictdump.py
@@ -26,6 +26,7 @@ by text strings to following file formats: `HDF5, INI, JSON`
"""
from collections import OrderedDict
+from collections.abc import Mapping
import json
import logging
import numpy
@@ -34,11 +35,16 @@ import sys
import h5py
from .configdict import ConfigDict
-from .utils import is_group, is_link, is_softlink, is_externallink
+from .utils import is_group
+from .utils import is_dataset
+from .utils import is_link
+from .utils import is_softlink
+from .utils import is_externallink
from .utils import is_file as is_h5_file_like
from .utils import open as h5open
from .utils import h5py_read_dataset
from .utils import H5pyAttributesReadWrapper
+from silx.utils.deprecation import deprecated_warning
__authors__ = ["P. Knobel"]
__license__ = "MIT"
@@ -66,7 +72,7 @@ def _prepare_hdf5_write_value(array_like):
return array
-class _SafeH5FileWrite(object):
+class _SafeH5FileWrite:
"""Context manager returning a :class:`h5py.File` object.
If this object is initialized with a file path, we open the file
@@ -82,7 +88,6 @@ class _SafeH5FileWrite(object):
"""
def __init__(self, h5file, mode="w"):
"""
-
:param h5file: HDF5 file path or :class:`h5py.File` instance
:param str mode: Can be ``"r+"`` (read/write, file must exist),
``"w"`` (write, existing file is lost), ``"w-"`` (write, fail if
@@ -106,7 +111,7 @@ class _SafeH5FileWrite(object):
self.h5file.close()
-class _SafeH5FileRead(object):
+class _SafeH5FileRead:
"""Context manager returning a :class:`h5py.File` or a
:class:`silx.io.spech5.SpecH5` or a :class:`silx.io.fabioh5.File` object.
@@ -136,18 +141,48 @@ class _SafeH5FileRead(object):
self.h5file.close()
+def _normalize_h5_path(h5root, h5path):
+ """
+ :param h5root: File name or h5py-like File, Group or Dataset
+ :param str h5path: relative to ``h5root``
+ :returns 2-tuple: (File or file object, h5path)
+ """
+ if is_group(h5root):
+ group_name = h5root.name
+ if group_name == "/":
+ pass
+ elif h5path:
+ h5path = group_name + "/" + h5path
+ else:
+ h5path = group_name
+ h5file = h5root.file
+ elif is_dataset(h5root):
+ h5path = h5root.name
+ h5file = h5root.file
+ else:
+ h5file = h5root
+ if not h5path:
+ h5path = "/"
+ elif not h5path.endswith("/"):
+ h5path += "/"
+ return h5file, h5path
+
+
def dicttoh5(treedict, h5file, h5path='/',
- mode="w", overwrite_data=False,
- create_dataset_args=None):
+ mode="w", overwrite_data=None,
+ create_dataset_args=None, update_mode=None):
"""Write a nested dictionary to a HDF5 file, using keys as member names.
If a dictionary value is a sub-dictionary, a group is created. If it is
any other data type, it is cast into a numpy array and written as a
:mod:`h5py` dataset. Dictionary keys must be strings and cannot contain
the ``/`` character.
-
+
If dictionary keys are tuples they are interpreted to set h5 attributes.
- The tuples should have the format (dataset_name,attr_name)
+ The tuples should have the format (dataset_name, attr_name).
+
+ Existing HDF5 items can be deleted by providing the dictionary value
+ ``None``, provided that ``update_mode in ["modify", "replace"]``.
.. note::
@@ -158,21 +193,29 @@ def dicttoh5(treedict, h5file, h5path='/',
to define sub trees. If tuples are used as keys they should have the
format (dataset_name,attr_name) and will add a 5h attribute with the
corresponding value.
- :param h5file: HDF5 file name or handle. If a file name is provided, the
- function opens the file in the specified mode and closes it again
- before completing.
- :param h5path: Target path in HDF5 file in which scan groups are created.
+ :param h5file: File name or h5py-like File, Group or Dataset
+ :param h5path: Target path in the HDF5 file relative to ``h5file``.
Default is root (``"/"``)
:param mode: Can be ``"r+"`` (read/write, file must exist),
``"w"`` (write, existing file is lost), ``"w-"`` (write, fail if
exists) or ``"a"`` (read/write if exists, create otherwise).
This parameter is ignored if ``h5file`` is a file handle.
- :param overwrite_data: If ``True``, existing groups and datasets can be
- overwritten, if ``False`` they are skipped. This parameter is only
- relevant if ``h5file_mode`` is ``"r+"`` or ``"a"``.
+ :param overwrite_data: Deprecated. ``True`` is approximately equivalent
+ to ``update_mode="modify"`` and ``False`` is equivalent to
+ ``update_mode="add"``.
:param create_dataset_args: Dictionary of args you want to pass to
``h5f.create_dataset``. This allows you to specify filters and
compression parameters. Don't specify ``name`` and ``data``.
+ :param update_mode: Can be ``add`` (default), ``modify`` or ``replace``.
+
+ * ``add``: Extend the existing HDF5 tree when possible. Existing HDF5
+ items (groups, datasets and attributes) remain untouched.
+ * ``modify``: Extend the existing HDF5 tree when possible, modify
+ existing attributes, modify same-sized dataset values and delete
+ HDF5 items with a ``None`` value in the dict tree.
+ * ``replace``: Replace the existing HDF5 tree. Items from the root of
+ the HDF5 tree that are not present in the root of the dict tree
+ will remain untouched.
Example::
@@ -201,44 +244,110 @@ def dicttoh5(treedict, h5file, h5path='/',
create_dataset_args=create_ds_args)
"""
- if not h5path.endswith("/"):
- h5path += "/"
+ if overwrite_data is not None:
+ reason = (
+ "`overwrite_data=True` becomes `update_mode='modify'` and "
+ "`overwrite_data=False` becomes `update_mode='add'`"
+ )
+ deprecated_warning(
+ type_="argument",
+ name="overwrite_data",
+ reason=reason,
+ replacement="update_mode",
+ since_version="0.15",
+ )
+
+ if update_mode is None:
+ if overwrite_data:
+ update_mode = "modify"
+ else:
+ update_mode = "add"
+ else:
+ valid_existing_values = ("add", "replace", "modify")
+ if update_mode not in valid_existing_values:
+ raise ValueError((
+ "Argument 'update_mode' can only have values: {}"
+ "".format(valid_existing_values)
+ ))
+ if overwrite_data is not None:
+ logger.warning("The argument `overwrite_data` is ignored")
- with _SafeH5FileWrite(h5file, mode=mode) as h5f:
- if isinstance(treedict, dict) and h5path != "/":
- if h5path not in h5f:
- h5f.create_group(h5path)
+ if not isinstance(treedict, Mapping):
+ raise TypeError("'treedict' must be a dictionary")
- for key in filter(lambda k: not isinstance(k, tuple), treedict):
- key_is_group = isinstance(treedict[key], dict)
- h5name = h5path + key
+ h5file, h5path = _normalize_h5_path(h5file, h5path)
- if key_is_group and treedict[key]:
- # non-empty group: recurse
- dicttoh5(treedict[key], h5f, h5name,
- overwrite_data=overwrite_data,
- create_dataset_args=create_dataset_args)
- continue
+ def _iter_treedict(attributes=False):
+ nonlocal treedict
+ for key, value in treedict.items():
+ if isinstance(key, tuple) == attributes:
+ yield key, value
- if h5name in h5f:
- # key already exists: delete or skip
- if overwrite_data is True:
- del h5f[h5name]
+ change_allowed = update_mode in ("replace", "modify")
+
+ with _SafeH5FileWrite(h5file, mode=mode) as h5f:
+ # Create the root of the tree
+ if h5path in h5f:
+ if not is_group(h5f[h5path]):
+ if update_mode == "replace":
+ del h5f[h5path]
+ h5f.create_group(h5path)
else:
- logger.warning('key (%s) already exists. '
- 'Not overwriting.' % (h5name))
- continue
+ return
+ else:
+ h5f.create_group(h5path)
- value = treedict[key]
+ # Loop over all groups, links and datasets
+ for key, value in _iter_treedict(attributes=False):
+ h5name = h5path + key
+ exists = h5name in h5f
- if value is None or key_is_group:
- # Create empty group
- h5f.create_group(h5name)
+ if value is None:
+ # Delete HDF5 item
+ if exists and change_allowed:
+ del h5f[h5name]
+ exists = False
+ elif isinstance(value, Mapping):
+ # HDF5 group
+ if exists and update_mode == "replace":
+ del h5f[h5name]
+ exists = False
+ if value:
+ dicttoh5(value, h5f, h5name,
+ update_mode=update_mode,
+ create_dataset_args=create_dataset_args)
+ elif not exists:
+ h5f.create_group(h5name)
elif is_link(value):
- h5f[h5name] = value
+ # HDF5 link
+ if exists and update_mode == "replace":
+ del h5f[h5name]
+ exists = False
+ if not exists:
+ # Create link from h5py link object
+ h5f[h5name] = value
else:
+ # HDF5 dataset
+ if exists and not change_allowed:
+ continue
data = _prepare_hdf5_write_value(value)
- # can't apply filters on scalars (datasets with shape == () )
+
+ # Edit the existing dataset
+ attrs_backup = None
+ if exists:
+ try:
+ h5f[h5name][()] = data
+ continue
+ except Exception:
+ # Delete the existing dataset
+ if update_mode != "replace":
+ if not is_dataset(h5f[h5name]):
+ continue
+ attrs_backup = dict(h5f[h5name].attrs)
+ del h5f[h5name]
+
+ # Create dataset
+ # can't apply filters on scalars (datasets with shape == ())
if data.shape == () or create_dataset_args is None:
h5f.create_dataset(h5name,
data=data)
@@ -246,36 +355,58 @@ def dicttoh5(treedict, h5file, h5path='/',
h5f.create_dataset(h5name,
data=data,
**create_dataset_args)
+ if attrs_backup:
+ h5f[h5name].attrs.update(attrs_backup)
- # deal with h5 attributes which have tuples as keys in treedict
- for key in filter(lambda k: isinstance(k, tuple), treedict):
- assert len(key) == 2, "attribute must be defined by 2 values"
+ # Loop over all attributes
+ for key, value in _iter_treedict(attributes=True):
+ if len(key) != 2:
+ raise ValueError("HDF5 attribute must be described by 2 values")
h5name = h5path + key[0]
attr_name = key[1]
if h5name not in h5f:
- # Create empty group if key for attr does not exist
+ # Create an empty group to store the attribute
h5f.create_group(h5name)
- logger.warning(
- "key (%s) does not exist. attr %s "
- "will be written to ." % (h5name, attr_name)
- )
-
- if attr_name in h5f[h5name].attrs:
- if not overwrite_data:
- logger.warning(
- "attribute %s@%s already exists. Not overwriting."
- "" % (h5name, attr_name)
- )
+
+ h5a = h5f[h5name].attrs
+ exists = attr_name in h5a
+
+ if value is None:
+ # Delete HDF5 attribute
+ if exists and change_allowed:
+ del h5a[attr_name]
+ exists = False
+ else:
+ # Add/modify HDF5 attribute
+ if exists and not change_allowed:
continue
+ data = _prepare_hdf5_write_value(value)
+ h5a[attr_name] = data
- # Write attribute
- value = treedict[key]
- data = _prepare_hdf5_write_value(value)
- h5f[h5name].attrs[attr_name] = data
+
+def _has_nx_class(treedict, key=""):
+ return key + "@NX_class" in treedict or \
+ (key, "NX_class") in treedict
+
+
+def _ensure_nx_class(treedict, parents=tuple()):
+ """Each group needs an "NX_class" attribute.
+ """
+ if _has_nx_class(treedict):
+ return
+ nparents = len(parents)
+ if nparents == 0:
+ treedict[("", "NX_class")] = "NXroot"
+ elif nparents == 1:
+ treedict[("", "NX_class")] = "NXentry"
+ else:
+ treedict[("", "NX_class")] = "NXcollection"
-def nexus_to_h5_dict(treedict, parents=tuple()):
+def nexus_to_h5_dict(
+ treedict, parents=tuple(), add_nx_class=True, has_nx_class=False
+):
"""The following conversions are applied:
* key with "{name}@{attr_name}" notation: key converted to 2-tuple
* key with ">{url}" notation: strip ">" and convert value to
@@ -286,14 +417,20 @@ def nexus_to_h5_dict(treedict, parents=tuple()):
to define sub tree. The ``"@"`` character is used to write attributes.
The ``">"`` prefix is used to define links.
:param parents: Needed to resolve up-links (tuple of HDF5 group names)
+ :param add_nx_class: Add "NX_class" attribute when missing
+ :param has_nx_class: The "NX_class" attribute is defined in the parent
:rtype dict:
"""
+ if not isinstance(treedict, Mapping):
+ raise TypeError("'treedict' must be a dictionary")
copy = dict()
for key, value in treedict.items():
if "@" in key:
+ # HDF5 attribute
key = tuple(key.rsplit("@", 1))
elif key.startswith(">"):
+ # HDF5 link
if isinstance(value, str):
key = key[1:]
first, sep, second = value.partition("::")
@@ -314,10 +451,19 @@ def nexus_to_h5_dict(treedict, parents=tuple()):
value = h5py.SoftLink(first)
elif is_link(value):
key = key[1:]
- if isinstance(value, dict):
- copy[key] = nexus_to_h5_dict(value, parents=parents+(key,))
+ if isinstance(value, Mapping):
+ # HDF5 group
+ key_has_nx_class = add_nx_class and _has_nx_class(treedict, key)
+ copy[key] = nexus_to_h5_dict(
+ value,
+ parents=parents+(key,),
+ add_nx_class=add_nx_class,
+ has_nx_class=key_has_nx_class)
else:
+ # HDF5 dataset or link
copy[key] = value
+ if add_nx_class and not has_nx_class:
+ _ensure_nx_class(copy, parents)
return copy
@@ -336,7 +482,8 @@ def h5_to_nexus_dict(treedict):
copy = dict()
for key, value in treedict.items():
if isinstance(key, tuple):
- assert len(key)==2, "attribute must be defined by 2 values"
+ if len(key) != 2:
+ raise ValueError("HDF5 attribute must be described by 2 values")
key = "%s@%s" % (key[0], key[1])
elif is_softlink(value):
key = ">" + key
@@ -344,7 +491,7 @@ def h5_to_nexus_dict(treedict):
elif is_externallink(value):
key = ">" + key
value = value.filename + "::" + value.path
- if isinstance(value, dict):
+ if isinstance(value, Mapping):
copy[key] = h5_to_nexus_dict(value)
else:
copy[key] = value
@@ -414,10 +561,8 @@ def h5todict(h5file,
scalars). In some cases, you may find that a list of heterogeneous
data types is converted to a numpy array of strings.
- :param h5file: File name or :class:`h5py.File` object or spech5 file or
- fabioh5 file.
- :param str path: Name of HDF5 group to use as dictionary root level,
- to read only a sub-group in the file
+ :param h5file: File name or h5py-like File, Group or Dataset
+ :param str path: Target path in the HDF5 file relative to ``h5file``
:param List[str] exclude_names: Groups and datasets whose name contains
a string in this list will be ignored. Default is None (ignore nothing)
:param bool asarray: True (default) to read scalar as arrays, False to
@@ -431,6 +576,7 @@ def h5todict(h5file,
- 'ignore': Ignore errors
:return: Nested dictionary
"""
+ h5file, path = _normalize_h5_path(h5file, path)
with _SafeH5FileRead(h5file) as h5f:
ddict = {}
if path not in h5f:
@@ -508,7 +654,7 @@ def h5todict(h5file,
return ddict
-def dicttonx(treedict, h5file, h5path="/", **kw):
+def dicttonx(treedict, h5file, h5path="/", add_nx_class=None, **kw):
"""
Write a nested dictionary to a HDF5 file, using string keys as member names.
The NeXus convention is used to identify attributes with ``"@"`` character,
@@ -521,6 +667,8 @@ def dicttonx(treedict, h5file, h5path="/", **kw):
and array-like objects as leafs. The ``"/"`` character can be used
to define sub tree. The ``"@"`` character is used to write attributes.
The ``">"`` prefix is used to define links.
+ :param add_nx_class: Add "NX_class" attribute when missing. By default it
+ is ``True`` when ``update_mode`` is ``"add"`` or ``None``.
The named parameters are passed to dicttoh5.
@@ -557,12 +705,17 @@ def dicttonx(treedict, h5file, h5path="/", **kw):
dicttonx(gauss,"test.h5")
"""
+ h5file, h5path = _normalize_h5_path(h5file, h5path)
parents = tuple(p for p in h5path.split("/") if p)
- nxtreedict = nexus_to_h5_dict(treedict, parents=parents)
+ if add_nx_class is None:
+ add_nx_class = kw.get("update_mode", None) in (None, "add")
+ nxtreedict = nexus_to_h5_dict(
+ treedict, parents=parents, add_nx_class=add_nx_class
+ )
dicttoh5(nxtreedict, h5file, h5path=h5path, **kw)
-def nxtodict(h5file, **kw):
+def nxtodict(h5file, include_attributes=True, **kw):
"""Read a HDF5 file and return a nested dictionary with the complete file
structure and all data.
@@ -571,7 +724,7 @@ def nxtodict(h5file, **kw):
The named parameters are passed to h5todict.
"""
- nxtreedict = h5todict(h5file, **kw)
+ nxtreedict = h5todict(h5file, include_attributes=include_attributes, **kw)
return h5_to_nexus_dict(nxtreedict)
diff --git a/silx/io/h5py_utils.py b/silx/io/h5py_utils.py
new file mode 100644
index 0000000..cbdb44a
--- /dev/null
+++ b/silx/io/h5py_utils.py
@@ -0,0 +1,317 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides utility methods on top of h5py, mainly to handle
+parallel writing and reading.
+"""
+
+__authors__ = ["W. de Nolf"]
+__license__ = "MIT"
+__date__ = "27/01/2020"
+
+
+import os
+import traceback
+import h5py
+
+from .._version import calc_hexversion
+from ..utils import retry as retry_mod
+
+H5PY_HEX_VERSION = calc_hexversion(*h5py.version.version_tuple[:3])
+HDF5_HEX_VERSION = calc_hexversion(*h5py.version.hdf5_version_tuple[:3])
+
+HDF5_SWMR_VERSION = calc_hexversion(*h5py.get_config().swmr_min_hdf5_version[:3])
+HDF5_TRACK_ORDER_VERSION = calc_hexversion(2, 9, 0)
+
+HAS_SWMR = HDF5_HEX_VERSION >= HDF5_SWMR_VERSION
+HAS_TRACK_ORDER = H5PY_HEX_VERSION >= HDF5_TRACK_ORDER_VERSION
+
+
+def _is_h5py_exception(e):
+ for frame in traceback.walk_tb(e.__traceback__):
+ if frame[0].f_locals.get("__package__", None) == "h5py":
+ return True
+ return False
+
+
+def _retry_h5py_error(e):
+ """
+ :param BaseException e:
+ :returns bool:
+ """
+ if _is_h5py_exception(e):
+ if isinstance(e, (OSError, RuntimeError)):
+ return True
+ elif isinstance(e, KeyError):
+ # For example this needs to be retried:
+ # KeyError: 'Unable to open object (bad object header version number)'
+ return "Unable to open object" in str(e)
+ elif isinstance(e, retry_mod.RetryError):
+ return True
+ return False
+
+
+def retry(**kw):
+ """Decorator for a method that needs to be executed until it not longer
+ fails on HDF5 IO. Mainly used for reading an HDF5 file that is being
+ written.
+
+ :param \**kw: see `silx.utils.retry`
+ """
+ kw.setdefault("retry_on_error", _retry_h5py_error)
+ return retry_mod.retry(**kw)
+
+
+def retry_contextmanager(**kw):
+ """Decorator to make a context manager from a method that needs to be
+ entered until it not longer fails on HDF5 IO. Mainly used for reading
+ an HDF5 file that is being written.
+
+ :param \**kw: see `silx.utils.retry_contextmanager`
+ """
+ kw.setdefault("retry_on_error", _retry_h5py_error)
+ return retry_mod.retry_contextmanager(**kw)
+
+
+def retry_in_subprocess(**kw):
+ """Same as `retry` but it also retries segmentation faults.
+
+ On Window you cannot use this decorator with the "@" syntax:
+
+ .. code-block:: python
+
+ def _method(*args, **kw):
+ ...
+
+ method = retry_in_subprocess()(_method)
+
+ :param \**kw: see `silx.utils.retry_in_subprocess`
+ """
+ kw.setdefault("retry_on_error", _retry_h5py_error)
+ return retry_mod.retry_in_subprocess(**kw)
+
+
+def group_has_end_time(h5item):
+ """Returns True when the HDF5 item is a Group with an "end_time"
+ dataset. A reader can use this as an indication that the Group
+ has been fully written (at least if the writer supports this).
+
+ :param Union[h5py.Group,h5py.Dataset] h5item:
+ :returns bool:
+ """
+ if isinstance(h5item, h5py.Group):
+ return "end_time" in h5item
+ else:
+ return False
+
+
+@retry_contextmanager()
+def open_item(filename, name, retry_invalid=False, validate=None):
+ """Yield an HDF5 dataset or group (retry until it can be instantiated).
+
+ :param str filename:
+ :param bool retry_invalid: retry when item is missing or not valid
+ :param callable or None validate:
+ :yields Dataset, Group or None:
+ """
+ with File(filename) as h5file:
+ try:
+ item = h5file[name]
+ except KeyError as e:
+ if "doesn't exist" in str(e):
+ if retry_invalid:
+ raise retry_mod.RetryError
+ else:
+ item = None
+ else:
+ raise
+ if callable(validate) and item is not None:
+ if not validate(item):
+ if retry_invalid:
+ raise retry_mod.RetryError
+ else:
+ item = None
+ yield item
+
+
+def _top_level_names(filename, include_only=group_has_end_time):
+ """Return all valid top-level HDF5 names.
+
+ :param str filename:
+ :param callable or None include_only:
+ :returns list(str):
+ """
+ with File(filename) as h5file:
+ try:
+ if callable(include_only):
+ return [name for name in h5file["/"] if include_only(h5file[name])]
+ else:
+ return list(h5file["/"])
+ except KeyError:
+ raise retry_mod.RetryError
+
+
+top_level_names = retry()(_top_level_names)
+safe_top_level_names = retry_in_subprocess()(_top_level_names)
+
+
+class File(h5py.File):
+ """Takes care of HDF5 file locking and SWMR mode without the need
+ to handle those explicitely.
+
+ When using this class, you cannot open different files simultatiously
+ with different modes because the locking flag is an environment variable.
+ """
+
+ _HDF5_FILE_LOCKING = None
+ _NOPEN = 0
+ _SWMR_LIBVER = "latest"
+
+ def __init__(
+ self,
+ filename,
+ mode=None,
+ enable_file_locking=None,
+ swmr=None,
+ libver=None,
+ **kwargs
+ ):
+ """The arguments `enable_file_locking` and `swmr` should not be
+ specified explicitly for normal use cases.
+
+ :param str filename:
+ :param str or None mode: read-only by default
+ :param bool or None enable_file_locking: by default it is disabled for `mode='r'`
+ and `swmr=False` and enabled for all
+ other modes.
+ :param bool or None swmr: try both modes when `mode='r'` and `swmr=None`
+ :param **kwargs: see `h5py.File.__init__`
+ """
+ if mode is None:
+ mode = "r"
+ elif mode not in ("r", "w", "w-", "x", "a", "r+"):
+ raise ValueError("invalid mode {}".format(mode))
+ if not HAS_SWMR:
+ swmr = False
+
+ if enable_file_locking is None:
+ enable_file_locking = bool(mode != "r" or swmr)
+ if self._NOPEN:
+ self._check_locking_env(enable_file_locking)
+ else:
+ self._set_locking_env(enable_file_locking)
+
+ if swmr and libver is None:
+ libver = self._SWMR_LIBVER
+
+ if HAS_TRACK_ORDER:
+ kwargs.setdefault("track_order", True)
+ try:
+ super().__init__(filename, mode=mode, swmr=swmr, libver=libver, **kwargs)
+ except OSError as e:
+ # wlock wSWMR rlock rSWMR OSError: Unable to open file (...)
+ # 1 TRUE FALSE FALSE FALSE -
+ # 2 TRUE FALSE FALSE TRUE -
+ # 3 TRUE FALSE TRUE FALSE unable to lock file, errno = 11, error message = 'Resource temporarily unavailable'
+ # 4 TRUE FALSE TRUE TRUE unable to lock file, errno = 11, error message = 'Resource temporarily unavailable'
+ # 5 TRUE TRUE FALSE FALSE file is already open for write (may use <h5clear file> to clear file consistency flags)
+ # 6 TRUE TRUE FALSE TRUE -
+ # 7 TRUE TRUE TRUE FALSE file is already open for write (may use <h5clear file> to clear file consistency flags)
+ # 8 TRUE TRUE TRUE TRUE -
+ if (
+ mode == "r"
+ and swmr is None
+ and "file is already open for write" in str(e)
+ ):
+ # Try reading in SWMR mode (situation 5 and 7)
+ swmr = True
+ if libver is None:
+ libver = self._SWMR_LIBVER
+ super().__init__(
+ filename, mode=mode, swmr=swmr, libver=libver, **kwargs
+ )
+ else:
+ raise
+ else:
+ self._add_nopen(1)
+ try:
+ if mode != "r" and swmr:
+ # Try setting writer in SWMR mode
+ self.swmr_mode = True
+ except Exception:
+ self.close()
+ raise
+
+ @classmethod
+ def _add_nopen(cls, v):
+ cls._NOPEN = max(cls._NOPEN + v, 0)
+
+ def close(self):
+ super().close()
+ self._add_nopen(-1)
+ if not self._NOPEN:
+ self._restore_locking_env()
+
+ def _set_locking_env(self, enable):
+ self._backup_locking_env()
+ if enable:
+ os.environ["HDF5_USE_FILE_LOCKING"] = "TRUE"
+ elif enable is None:
+ try:
+ del os.environ["HDF5_USE_FILE_LOCKING"]
+ except KeyError:
+ pass
+ else:
+ os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
+
+ def _get_locking_env(self):
+ v = os.environ.get("HDF5_USE_FILE_LOCKING")
+ if v == "TRUE":
+ return True
+ elif v is None:
+ return None
+ else:
+ return False
+
+ def _check_locking_env(self, enable):
+ if enable != self._get_locking_env():
+ if enable:
+ raise RuntimeError(
+ "Close all HDF5 files before enabling HDF5 file locking"
+ )
+ else:
+ raise RuntimeError(
+ "Close all HDF5 files before disabling HDF5 file locking"
+ )
+
+ def _backup_locking_env(self):
+ v = os.environ.get("HDF5_USE_FILE_LOCKING")
+ if v is None:
+ self._HDF5_FILE_LOCKING = None
+ else:
+ self._HDF5_FILE_LOCKING = v == "TRUE"
+
+ def _restore_locking_env(self):
+ self._set_locking_env(self._HDF5_FILE_LOCKING)
+ self._HDF5_FILE_LOCKING = None
diff --git a/silx/io/spech5.py b/silx/io/spech5.py
index 79fd2e4..1eaec7c 100644
--- a/silx/io/spech5.py
+++ b/silx/io/spech5.py
@@ -1,6 +1,6 @@
# coding: utf-8
# /*##########################################################################
-# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -194,7 +194,7 @@ import numpy
import six
from silx import version as silx_version
-from .specfile import SpecFile
+from .specfile import SpecFile, SfErrColNotFound
from . import commonh5
__authors__ = ["P. Knobel", "D. Naudet"]
@@ -670,6 +670,10 @@ class PositionersGroup(commonh5.Group, SpecH5Group):
def __init__(self, parent, scan):
commonh5.Group.__init__(self, name="positioners", parent=parent,
attrs={"NX_class": to_h5py_utf8("NXcollection")})
+
+ dataset_info = [] # Store list of positioner's (name, value)
+ is_error = False # True if error encountered
+
for motor_name in scan.motor_names:
safe_motor_name = motor_name.replace("/", "%")
if motor_name in scan.labels and scan.data.shape[0] > 0:
@@ -678,10 +682,24 @@ class PositionersGroup(commonh5.Group, SpecH5Group):
else:
# Take value from #P scan header.
# (may return float("inf") if #P line is missing from scan hdr)
- motor_value = scan.motor_position_by_name(motor_name)
- self.add_node(SpecH5NodeDataset(name=safe_motor_name,
- data=motor_value,
- parent=self))
+ try:
+ motor_value = scan.motor_position_by_name(motor_name)
+ except SfErrColNotFound:
+ is_error = True
+ motor_value = float('inf')
+ dataset_info.append((safe_motor_name, motor_value))
+
+ if is_error: # Filter-out scalar values
+ logger1.warning("Mismatching number of elements in #P and #O: Ignoring")
+ dataset_info = [
+ (name, value) for name, value in dataset_info
+ if not isinstance(value, float)]
+
+ for name, value in dataset_info:
+ self.add_node(SpecH5NodeDataset(
+ name=name,
+ data=value,
+ parent=self))
class InstrumentMcaGroup(commonh5.Group, SpecH5Group):
diff --git a/silx/io/test/__init__.py b/silx/io/test/__init__.py
index a309ee9..68b6e9b 100644
--- a/silx/io/test/__init__.py
+++ b/silx/io/test/__init__.py
@@ -40,6 +40,7 @@ from .test_nxdata import suite as test_nxdata_suite
from .test_commonh5 import suite as test_commonh5_suite
from .test_rawh5 import suite as test_rawh5_suite
from .test_url import suite as test_url_suite
+from .test_h5py_utils import suite as test_h5py_utils_suite
def suite():
@@ -56,4 +57,5 @@ def suite():
test_suite.addTest(test_commonh5_suite())
test_suite.addTest(test_rawh5_suite())
test_suite.addTest(test_url_suite())
+ test_suite.addTest(test_h5py_utils_suite())
return test_suite
diff --git a/silx/io/test/test_dictdump.py b/silx/io/test/test_dictdump.py
index b99116b..93c9183 100644
--- a/silx/io/test/test_dictdump.py
+++ b/silx/io/test/test_dictdump.py
@@ -33,6 +33,7 @@ import os
import tempfile
import unittest
import h5py
+from copy import deepcopy
from collections import defaultdict
@@ -72,7 +73,63 @@ link_attrs["links"]["absolute_softlink"] = h5py.SoftLink("/links/group/dataset")
link_attrs["links"]["external_link"] = h5py.ExternalLink(ext_filename, "/ext_group/dataset")
-class TestDictToH5(unittest.TestCase):
+class DictTestCase(unittest.TestCase):
+
+ def assertRecursiveEqual(self, expected, actual, nodes=tuple()):
+ err_msg = "\n\n Tree nodes: {}".format(nodes)
+ if isinstance(expected, dict):
+ self.assertTrue(isinstance(actual, dict), msg=err_msg)
+ self.assertEqual(
+ set(expected.keys()),
+ set(actual.keys()),
+ msg=err_msg
+ )
+ for k in actual:
+ self.assertRecursiveEqual(
+ expected[k],
+ actual[k],
+ nodes=nodes + (k,),
+ )
+ return
+ if isinstance(actual, numpy.ndarray):
+ actual = actual.tolist()
+ if isinstance(expected, numpy.ndarray):
+ expected = expected.tolist()
+
+ self.assertEqual(expected, actual, msg=err_msg)
+
+
+class H5DictTestCase(DictTestCase):
+
+ def _dictRoundTripNormalize(self, treedict):
+ """Convert the dictionary as expected from a round-trip
+ treedict -> dicttoh5 -> h5todict -> newtreedict
+ """
+ for key, value in list(treedict.items()):
+ if isinstance(value, dict):
+ self._dictRoundTripNormalize(value)
+
+ # Expand treedict[("group", "attr_name")]
+ # to treedict["group"]["attr_name"]
+ for key, value in list(treedict.items()):
+ if not isinstance(key, tuple):
+ continue
+ # Put the attribute inside the group
+ grpname, attr = key
+ if not grpname:
+ continue
+ group = treedict.setdefault(grpname, dict())
+ if isinstance(group, dict):
+ del treedict[key]
+ group[("", attr)] = value
+
+ def dictRoundTripNormalize(self, treedict):
+ treedict2 = deepcopy(treedict)
+ self._dictRoundTripNormalize(treedict2)
+ return treedict2
+
+
+class TestDictToH5(H5DictTestCase):
def setUp(self):
self.tempdir = tempfile.mkdtemp()
self.h5_fname = os.path.join(self.tempdir, "cityattrs.h5")
@@ -110,14 +167,13 @@ class TestDictToH5(unittest.TestCase):
min(ddict["city attributes"]["Europe"]["France"]["Grenoble"]["coordinates"]),
5.7196)
- def testH5Overwrite(self):
+ def testH5OverwriteDeprecatedApi(self):
dd = ConfigDict({'t': True})
dicttoh5(h5file=self.h5_fname, treedict=dd, mode='a')
dd = ConfigDict({'t': False})
- with TestLogging(dictdump_logger, warning=1):
- dicttoh5(h5file=self.h5_fname, treedict=dd, mode='a',
- overwrite_data=False)
+ dicttoh5(h5file=self.h5_fname, treedict=dd, mode='a',
+ overwrite_data=False)
res = h5todict(self.h5_fname)
assert(res['t'] == True)
@@ -200,8 +256,7 @@ class TestDictToH5(unittest.TestCase):
("group", "attr"): 10,
}
with h5py.File(self.h5_fname, "w") as h5file:
- with TestLogging(dictdump_logger, warning=1):
- dictdump.dicttoh5(ddict, h5file)
+ dictdump.dicttoh5(ddict, h5file)
self.assertEqual(h5file["group"].attrs['attr'], 10)
def testFlatDict(self):
@@ -241,8 +296,223 @@ class TestDictToH5(unittest.TestCase):
numpy.testing.assert_array_equal(h5py_read_dataset(h5file["darks"]["0"]),
ddict['darks']['0'])
-
-class TestH5ToDict(unittest.TestCase):
+ def testOverwrite(self):
+ # Tree structure that will be tested
+ group1 = {
+ ("", "attr2"): "original2",
+ "dset1": 0,
+ "dset2": [0, 1],
+ ("dset1", "attr1"): "original1",
+ ("dset1", "attr2"): "original2",
+ ("dset2", "attr1"): "original1",
+ ("dset2", "attr2"): "original2",
+ }
+ group2 = {
+ "subgroup1": group1.copy(),
+ "subgroup2": group1.copy(),
+ ("subgroup1", "attr1"): "original1",
+ ("subgroup2", "attr1"): "original1"
+ }
+ group2.update(group1)
+ # initial HDF5 tree
+ otreedict = {
+ ('', 'attr1'): "original1",
+ ('', 'attr2'): "original2",
+ 'group1': group1,
+ 'group2': group2,
+ ('group1', 'attr1'): "original1",
+ ('group2', 'attr1'): "original1"
+ }
+ wtreedict = None # dumped dictionary
+ etreedict = None # expected HDF5 tree after dump
+
+ def reset_file():
+ dicttoh5(
+ otreedict,
+ h5file=self.h5_fname,
+ mode="w",
+ )
+
+ def append_file(update_mode):
+ dicttoh5(
+ wtreedict,
+ h5file=self.h5_fname,
+ mode="a",
+ update_mode=update_mode
+ )
+
+ def assert_file():
+ rtreedict = h5todict(
+ self.h5_fname,
+ include_attributes=True,
+ asarray=False
+ )
+ netreedict = self.dictRoundTripNormalize(etreedict)
+ try:
+ self.assertRecursiveEqual(netreedict, rtreedict)
+ except AssertionError:
+ from pprint import pprint
+ print("\nDUMP:")
+ pprint(wtreedict)
+ print("\nEXPECTED:")
+ pprint(netreedict)
+ print("\nHDF5:")
+ pprint(rtreedict)
+ raise
+
+ def assert_append(update_mode):
+ append_file(update_mode)
+ assert_file()
+
+ # Test wrong arguments
+ with self.assertRaises(ValueError):
+ dicttoh5(
+ otreedict,
+ h5file=self.h5_fname,
+ mode="w",
+ update_mode="wrong-value"
+ )
+
+ # No writing
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ assert_file()
+
+ # Write identical dictionary
+ wtreedict = deepcopy(otreedict)
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add", "modify", "replace"]:
+ assert_append(update_mode)
+
+ # Write empty dictionary
+ wtreedict = dict()
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add", "modify", "replace"]:
+ assert_append(update_mode)
+
+ # Modified dataset
+ wtreedict = dict()
+ wtreedict["group2"] = dict()
+ wtreedict["group2"]["subgroup2"] = dict()
+ wtreedict["group2"]["subgroup2"]["dset1"] = {"dset3": [10, 20]}
+ wtreedict["group2"]["subgroup2"]["dset2"] = [10, 20]
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add"]:
+ assert_append(update_mode)
+
+ etreedict["group2"]["subgroup2"]["dset2"] = [10, 20]
+ assert_append("modify")
+
+ etreedict["group2"] = dict()
+ del etreedict[("group2", "attr1")]
+ etreedict["group2"]["subgroup2"] = dict()
+ etreedict["group2"]["subgroup2"]["dset1"] = {"dset3": [10, 20]}
+ etreedict["group2"]["subgroup2"]["dset2"] = [10, 20]
+ assert_append("replace")
+
+ # Modified group
+ wtreedict = dict()
+ wtreedict["group2"] = dict()
+ wtreedict["group2"]["subgroup2"] = [0, 1]
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add", "modify"]:
+ assert_append(update_mode)
+
+ etreedict["group2"] = dict()
+ del etreedict[("group2", "attr1")]
+ etreedict["group2"]["subgroup2"] = [0, 1]
+ assert_append("replace")
+
+ # Modified attribute
+ wtreedict = dict()
+ wtreedict["group2"] = dict()
+ wtreedict["group2"]["subgroup2"] = dict()
+ wtreedict["group2"]["subgroup2"][("dset1", "attr1")] = "modified"
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add"]:
+ assert_append(update_mode)
+
+ etreedict["group2"]["subgroup2"][("dset1", "attr1")] = "modified"
+ assert_append("modify")
+
+ etreedict["group2"] = dict()
+ del etreedict[("group2", "attr1")]
+ etreedict["group2"]["subgroup2"] = dict()
+ etreedict["group2"]["subgroup2"]["dset1"] = dict()
+ etreedict["group2"]["subgroup2"]["dset1"][("", "attr1")] = "modified"
+ assert_append("replace")
+
+ # Delete group
+ wtreedict = dict()
+ wtreedict["group2"] = dict()
+ wtreedict["group2"]["subgroup2"] = None
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add"]:
+ assert_append(update_mode)
+
+ del etreedict["group2"]["subgroup2"]
+ del etreedict["group2"][("subgroup2", "attr1")]
+ assert_append("modify")
+
+ etreedict["group2"] = dict()
+ del etreedict[("group2", "attr1")]
+ assert_append("replace")
+
+ # Delete dataset
+ wtreedict = dict()
+ wtreedict["group2"] = dict()
+ wtreedict["group2"]["subgroup2"] = dict()
+ wtreedict["group2"]["subgroup2"]["dset2"] = None
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add"]:
+ assert_append(update_mode)
+
+ del etreedict["group2"]["subgroup2"]["dset2"]
+ del etreedict["group2"]["subgroup2"][("dset2", "attr1")]
+ del etreedict["group2"]["subgroup2"][("dset2", "attr2")]
+ assert_append("modify")
+
+ etreedict["group2"] = dict()
+ del etreedict[("group2", "attr1")]
+ etreedict["group2"]["subgroup2"] = dict()
+ assert_append("replace")
+
+ # Delete attribute
+ wtreedict = dict()
+ wtreedict["group2"] = dict()
+ wtreedict["group2"]["subgroup2"] = dict()
+ wtreedict["group2"]["subgroup2"][("dset2", "attr1")] = None
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add"]:
+ assert_append(update_mode)
+
+ del etreedict["group2"]["subgroup2"][("dset2", "attr1")]
+ assert_append("modify")
+
+ etreedict["group2"] = dict()
+ del etreedict[("group2", "attr1")]
+ etreedict["group2"]["subgroup2"] = dict()
+ etreedict["group2"]["subgroup2"]["dset2"] = dict()
+ assert_append("replace")
+
+
+class TestH5ToDict(H5DictTestCase):
def setUp(self):
self.tempdir = tempfile.mkdtemp()
self.h5_fname = os.path.join(self.tempdir, "cityattrs.h5")
@@ -313,7 +583,7 @@ class TestH5ToDict(unittest.TestCase):
numpy.testing.assert_array_equal(ddict[("", "attr_2utf8")], adict[("", "attr_2utf8")])
-class TestDictToNx(unittest.TestCase):
+class TestDictToNx(H5DictTestCase):
def setUp(self):
self.tempdir = tempfile.mkdtemp()
self.h5_fname = os.path.join(self.tempdir, "nx.h5")
@@ -416,8 +686,121 @@ class TestDictToNx(unittest.TestCase):
with h5py.File(self.h5_fname, "r") as h5file:
self.assertEqual(h5file["/links/group/subgroup/relative_softlink"][()], 10)
-
-class TestNxToDict(unittest.TestCase):
+ def testOverwrite(self):
+ entry_name = "entry"
+ wtreedict = {
+ "group1": {"a": 1, "b": 2},
+ "group2@attr3": "attr3",
+ "group2@attr4": "attr4",
+ "group2": {
+ "@attr1": "attr1",
+ "@attr2": "attr2",
+ "c": 3,
+ "d": 4,
+ "dataset4": 8,
+ "dataset4@units": "keV",
+ },
+ "group3": {"subgroup": {"e": 9, "f": 10}},
+ "dataset1": 5,
+ "dataset2": 6,
+ "dataset3": 7,
+ "dataset3@units": "mm",
+ }
+ esubtree = {
+ "@NX_class": "NXentry",
+ "group1": {"@NX_class": "NXcollection", "a": 1, "b": 2},
+ "group2": {
+ "@NX_class": "NXcollection",
+ "@attr1": "attr1",
+ "@attr2": "attr2",
+ "@attr3": "attr3",
+ "@attr4": "attr4",
+ "c": 3,
+ "d": 4,
+ "dataset4": 8,
+ "dataset4@units": "keV",
+ },
+ "group3": {
+ "@NX_class": "NXcollection",
+ "subgroup": {"@NX_class": "NXcollection", "e": 9, "f": 10},
+ },
+ "dataset1": 5,
+ "dataset2": 6,
+ "dataset3": 7,
+ "dataset3@units": "mm",
+ }
+ etreedict = {entry_name: esubtree}
+
+ def append_file(update_mode, add_nx_class):
+ dictdump.dicttonx(
+ wtreedict,
+ h5file=self.h5_fname,
+ mode="a",
+ h5path=entry_name,
+ update_mode=update_mode,
+ add_nx_class=add_nx_class
+ )
+
+ def assert_file():
+ rtreedict = dictdump.nxtodict(
+ self.h5_fname,
+ include_attributes=True,
+ asarray=False,
+ )
+ netreedict = self.dictRoundTripNormalize(etreedict)
+ try:
+ self.assertRecursiveEqual(netreedict, rtreedict)
+ except AssertionError:
+ from pprint import pprint
+ print("\nDUMP:")
+ pprint(wtreedict)
+ print("\nEXPECTED:")
+ pprint(netreedict)
+ print("\nHDF5:")
+ pprint(rtreedict)
+ raise
+
+ def assert_append(update_mode, add_nx_class=None):
+ append_file(update_mode, add_nx_class=add_nx_class)
+ assert_file()
+
+ # First to an empty file
+ assert_append(None)
+
+ # Add non-existing attributes/datasets/groups
+ wtreedict["group1"].pop("a")
+ wtreedict["group2"].pop("@attr1")
+ wtreedict["group2"]["@attr2"] = "attr3" # only for update
+ wtreedict["group2"]["@type"] = "test"
+ wtreedict["group2"]["dataset4"] = 9 # only for update
+ del wtreedict["group2"]["dataset4@units"]
+ wtreedict["group3"] = {}
+ esubtree["group2"]["@type"] = "test"
+ assert_append("add")
+
+ # Add update existing attributes and datasets
+ esubtree["group2"]["@attr2"] = "attr3"
+ esubtree["group2"]["dataset4"] = 9
+ assert_append("modify")
+
+ # Do not add missing NX_class by default when updating
+ wtreedict["group2"]["@NX_class"] = "NXprocess"
+ esubtree["group2"]["@NX_class"] = "NXprocess"
+ assert_append("modify")
+ del wtreedict["group2"]["@NX_class"]
+ assert_append("modify")
+
+ # Overwrite existing groups/datasets/attributes
+ esubtree["group1"].pop("a")
+ esubtree["group2"].pop("@attr1")
+ esubtree["group2"]["@NX_class"] = "NXcollection"
+ esubtree["group2"]["dataset4"] = 9
+ del esubtree["group2"]["dataset4@units"]
+ esubtree["group3"] = {"@NX_class": "NXcollection"}
+ assert_append("replace", add_nx_class=True)
+
+
+class TestNxToDict(H5DictTestCase):
def setUp(self):
self.tempdir = tempfile.mkdtemp()
self.h5_fname = os.path.join(self.tempdir, "nx.h5")
@@ -510,7 +893,7 @@ class TestNxToDict(unittest.TestCase):
h5todict(self.h5_fname, path="/Mars", errors='raise')
-class TestDictToJson(unittest.TestCase):
+class TestDictToJson(DictTestCase):
def setUp(self):
self.dir_path = tempfile.mkdtemp()
self.json_fname = os.path.join(self.dir_path, "cityattrs.json")
@@ -528,7 +911,7 @@ class TestDictToJson(unittest.TestCase):
self.assertIn('"inhabitants": 160215', json_content)
-class TestDictToIni(unittest.TestCase):
+class TestDictToIni(DictTestCase):
def setUp(self):
self.dir_path = tempfile.mkdtemp()
self.ini_fname = os.path.join(self.dir_path, "test.ini")
diff --git a/silx/io/test/test_h5py_utils.py b/silx/io/test/test_h5py_utils.py
new file mode 100644
index 0000000..2e2e3dd
--- /dev/null
+++ b/silx/io/test/test_h5py_utils.py
@@ -0,0 +1,397 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for h5py utilities"""
+
+__authors__ = ["W. de Nolf"]
+__license__ = "MIT"
+__date__ = "27/01/2020"
+
+
+import unittest
+import os
+import sys
+import time
+import shutil
+import tempfile
+import threading
+import multiprocessing
+from contextlib import contextmanager
+
+from .. import h5py_utils
+from ...utils.retry import RetryError, RetryTimeoutError
+
+IS_WINDOWS = sys.platform == "win32"
+
+
+def _subprocess_context_main(queue, contextmgr, *args, **kw):
+ try:
+ with contextmgr(*args, **kw):
+ queue.put(None)
+ threading.Event().wait()
+ except Exception:
+ queue.put(None)
+ raise
+
+
+@contextmanager
+def _subprocess_context(contextmgr, *args, **kw):
+ timeout = kw.pop("timeout", 10)
+ queue = multiprocessing.Queue(maxsize=1)
+ p = multiprocessing.Process(
+ target=_subprocess_context_main, args=(queue, contextmgr) + args, kwargs=kw
+ )
+ p.start()
+ try:
+ queue.get(timeout=timeout)
+ yield
+ finally:
+ try:
+ p.kill()
+ except AttributeError:
+ p.terminate()
+ p.join(timeout)
+
+
+@contextmanager
+def _open_context(filename, **kw):
+ with h5py_utils.File(filename, **kw) as f:
+ if kw.get("mode") == "w":
+ f["check"] = True
+ f.flush()
+ yield f
+
+
+def _cause_segfault():
+ import ctypes
+
+ i = ctypes.c_char(b"a")
+ j = ctypes.pointer(i)
+ c = 0
+ while True:
+ j[c] = b"a"
+ c += 1
+
+
+def _top_level_names_test(txtfilename, *args, **kw):
+ sys.stderr = open(os.devnull, "w")
+
+ with open(txtfilename, mode="r") as f:
+ failcounter = int(f.readline().strip())
+
+ ncausefailure = kw.pop("ncausefailure")
+ faildelay = kw.pop("faildelay")
+ if failcounter < ncausefailure:
+ time.sleep(faildelay)
+ failcounter += 1
+ with open(txtfilename, mode="w") as f:
+ f.write(str(failcounter))
+ if failcounter % 2:
+ raise RetryError
+ else:
+ _cause_segfault()
+ return h5py_utils._top_level_names(*args, **kw)
+
+
+top_level_names_test = h5py_utils.retry_in_subprocess()(_top_level_names_test)
+
+
+def subtests(test):
+ def wrapper(self):
+ for _ in self._subtests():
+ with self.subTest(**self._subtest_options):
+ test(self)
+
+ return wrapper
+
+
+class TestH5pyUtils(unittest.TestCase):
+ def setUp(self):
+ self.test_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.test_dir)
+
+ def _subtests(self):
+ self._subtest_options = {"mode": "w"}
+ self.filename_generator = self._filenames()
+ yield
+ self._subtest_options = {"mode": "w", "libver": "latest"}
+ self.filename_generator = self._filenames()
+ yield
+
+ @property
+ def _liber_allows_concurrent_access(self):
+ return self._subtest_options.get("libver") in [None, "earliest", "v18"]
+
+ def _filenames(self):
+ i = 1
+ while True:
+ filename = os.path.join(self.test_dir, "file{}.h5".format(i))
+ with self._open_context(filename):
+ pass
+ yield filename
+ i += 1
+
+ def _new_filename(self):
+ return next(self.filename_generator)
+
+ @contextmanager
+ def _open_context(self, filename, **kwargs):
+ kw = self._subtest_options
+ kw.update(kwargs)
+ with _open_context(filename, **kw) as f:
+
+ yield f
+
+ @contextmanager
+ def _open_context_subprocess(self, filename, **kwargs):
+ kw = self._subtest_options
+ kw.update(kwargs)
+ with _subprocess_context(_open_context, filename, **kw):
+ yield
+
+ def _assert_hdf5_data(self, f):
+ self.assertTrue(f["check"][()])
+
+ def _validate_hdf5_data(self, filename, swmr=False):
+ with self._open_context(filename, mode="r") as f:
+ self.assertEqual(f.swmr_mode, swmr)
+ self._assert_hdf5_data(f)
+
+ @subtests
+ def test_modes_single_process(self):
+ orig = os.environ.get("HDF5_USE_FILE_LOCKING")
+ filename1 = self._new_filename()
+ self.assertEqual(orig, os.environ.get("HDF5_USE_FILE_LOCKING"))
+ filename2 = self._new_filename()
+ self.assertEqual(orig, os.environ.get("HDF5_USE_FILE_LOCKING"))
+ with self._open_context(filename1, mode="r"):
+ with self._open_context(filename2, mode="r"):
+ pass
+ for mode in ["w", "a"]:
+ with self.assertRaises(RuntimeError):
+ with self._open_context(filename2, mode=mode):
+ pass
+ self.assertEqual(orig, os.environ.get("HDF5_USE_FILE_LOCKING"))
+ with self._open_context(filename1, mode="a"):
+ for mode in ["w", "a"]:
+ with self._open_context(filename2, mode=mode):
+ pass
+ with self.assertRaises(RuntimeError):
+ with self._open_context(filename2, mode="r"):
+ pass
+ self.assertEqual(orig, os.environ.get("HDF5_USE_FILE_LOCKING"))
+
+ @subtests
+ def test_modes_multi_process(self):
+ if not self._liber_allows_concurrent_access:
+ # A concurrent reader with HDF5_USE_FILE_LOCKING=FALSE is
+ # no longer works with HDF5 >=1.10 (you get an exception
+ # when trying to open the file)
+ return
+ filename = self._new_filename()
+
+ # File open by truncating writer
+ with self._open_context_subprocess(filename, mode="w"):
+ with self._open_context(filename, mode="r") as f:
+ self._assert_hdf5_data(f)
+ if IS_WINDOWS:
+ with self._open_context(filename, mode="a") as f:
+ self._assert_hdf5_data(f)
+ else:
+ with self.assertRaises(OSError):
+ with self._open_context(filename, mode="a") as f:
+ pass
+ self._validate_hdf5_data(filename)
+
+ # File open by appending writer
+ with self._open_context_subprocess(filename, mode="a"):
+ with self._open_context(filename, mode="r") as f:
+ self._assert_hdf5_data(f)
+ if IS_WINDOWS:
+ with self._open_context(filename, mode="a") as f:
+ self._assert_hdf5_data(f)
+ else:
+ with self.assertRaises(OSError):
+ with self._open_context(filename, mode="a") as f:
+ pass
+ self._validate_hdf5_data(filename)
+
+ # File open by reader
+ with self._open_context_subprocess(filename, mode="r"):
+ with self._open_context(filename, mode="r") as f:
+ self._assert_hdf5_data(f)
+ with self._open_context(filename, mode="a") as f:
+ pass
+ self._validate_hdf5_data(filename)
+
+ # File open by locking reader
+ with _subprocess_context(
+ _open_context, filename, mode="r", enable_file_locking=True
+ ):
+ with self._open_context(filename, mode="r") as f:
+ self._assert_hdf5_data(f)
+ if IS_WINDOWS:
+ with self._open_context(filename, mode="a") as f:
+ self._assert_hdf5_data(f)
+ else:
+ with self.assertRaises(OSError):
+ with self._open_context(filename, mode="a") as f:
+ pass
+ self._validate_hdf5_data(filename)
+
+ @subtests
+ @unittest.skipIf(not h5py_utils.HAS_SWMR, "SWMR not supported")
+ def test_modes_multi_process_swmr(self):
+ filename = self._new_filename()
+
+ with self._open_context(filename, mode="w", libver="latest") as f:
+ pass
+
+ # File open by SWMR writer
+ with self._open_context_subprocess(filename, mode="a", swmr=True):
+ with self._open_context(filename, mode="r") as f:
+ assert f.swmr_mode
+ self._assert_hdf5_data(f)
+ with self.assertRaises(OSError):
+ with self._open_context(filename, mode="a") as f:
+ pass
+ self._validate_hdf5_data(filename, swmr=True)
+
+ @subtests
+ def test_retry_defaults(self):
+ filename = self._new_filename()
+
+ names = h5py_utils.top_level_names(filename)
+ self.assertEqual(names, [])
+
+ names = h5py_utils.safe_top_level_names(filename)
+ self.assertEqual(names, [])
+
+ names = h5py_utils.top_level_names(filename, include_only=None)
+ self.assertEqual(names, ["check"])
+
+ names = h5py_utils.safe_top_level_names(filename, include_only=None)
+ self.assertEqual(names, ["check"])
+
+ with h5py_utils.open_item(filename, "/check", validate=lambda x: False) as item:
+ self.assertEqual(item, None)
+
+ with h5py_utils.open_item(filename, "/check", validate=None) as item:
+ self.assertTrue(item[()])
+
+ with self.assertRaises(RetryTimeoutError):
+ with h5py_utils.open_item(
+ filename,
+ "/check",
+ retry_timeout=0.1,
+ retry_invalid=True,
+ validate=lambda x: False,
+ ) as item:
+ pass
+
+ ncall = 0
+
+ def validate(item):
+ nonlocal ncall
+ if ncall >= 1:
+ return True
+ else:
+ ncall += 1
+ raise RetryError
+
+ with h5py_utils.open_item(
+ filename, "/check", validate=validate, retry_timeout=1, retry_invalid=True
+ ) as item:
+ self.assertTrue(item[()])
+
+ @subtests
+ def test_retry_custom(self):
+ filename = self._new_filename()
+ ncausefailure = 3
+ faildelay = 0.1
+ sufficient_timeout = ncausefailure * (faildelay + 10)
+ insufficient_timeout = ncausefailure * faildelay * 0.5
+
+ @h5py_utils.retry_contextmanager()
+ def open_item(filename, name):
+ nonlocal failcounter
+ if failcounter < ncausefailure:
+ time.sleep(faildelay)
+ failcounter += 1
+ raise RetryError
+ with h5py_utils.File(filename) as h5file:
+ yield h5file[name]
+
+ failcounter = 0
+ kw = {"retry_timeout": sufficient_timeout}
+ with open_item(filename, "/check", **kw) as item:
+ self.assertTrue(item[()])
+
+ failcounter = 0
+ kw = {"retry_timeout": insufficient_timeout}
+ with self.assertRaises(RetryTimeoutError):
+ with open_item(filename, "/check", **kw) as item:
+ pass
+
+ @subtests
+ def test_retry_in_subprocess(self):
+ filename = self._new_filename()
+ txtfilename = os.path.join(self.test_dir, "failcounter.txt")
+ ncausefailure = 3
+ faildelay = 0.1
+ sufficient_timeout = ncausefailure * (faildelay + 10)
+ insufficient_timeout = ncausefailure * faildelay * 0.5
+
+ kw = {
+ "retry_timeout": sufficient_timeout,
+ "include_only": None,
+ "ncausefailure": ncausefailure,
+ "faildelay": faildelay,
+ }
+ with open(txtfilename, mode="w") as f:
+ f.write("0")
+ names = top_level_names_test(txtfilename, filename, **kw)
+ self.assertEqual(names, ["check"])
+
+ kw = {
+ "retry_timeout": insufficient_timeout,
+ "include_only": None,
+ "ncausefailure": ncausefailure,
+ "faildelay": faildelay,
+ }
+ with open(txtfilename, mode="w") as f:
+ f.write("0")
+ with self.assertRaises(RetryTimeoutError):
+ top_level_names_test(txtfilename, filename, **kw)
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestH5pyUtils))
+ return test_suite
+
+
+if __name__ == "__main__":
+ unittest.main(defaultTest="suite")
diff --git a/silx/io/url.py b/silx/io/url.py
index 044977c..66b75f0 100644
--- a/silx/io/url.py
+++ b/silx/io/url.py
@@ -344,6 +344,8 @@ class DataUrl(object):
:rtype: bool
"""
file_path = self.file_path()
+ if file_path is None:
+ return False
if len(file_path) > 0:
if file_path[0] == "/":
return True
diff --git a/silx/opencl/codec/test/test_byte_offset.py b/silx/opencl/codec/test/test_byte_offset.py
index e523b0f..d1482ce 100644
--- a/silx/opencl/codec/test/test_byte_offset.py
+++ b/silx/opencl/codec/test/test_byte_offset.py
@@ -37,7 +37,7 @@ __authors__ = ["Jérôme Kieffer"]
__contact__ = "jerome.kieffer@esrf.eu"
__license__ = "MIT"
__copyright__ = "2013 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "10/11/2017"
+__date__ = "02/03/2021"
import sys
import time
@@ -78,8 +78,8 @@ class TestByteOffset(unittest.TestCase):
tests the byte offset decompression on GPU
"""
ref, raw = self._create_test_data(shape=(91, 97), nexcept=229)
- #ref, raw = self._create_test_data(shape=(7, 9), nexcept=0)
-
+ # ref, raw = self._create_test_data(shape=(7, 9), nexcept=0)
+
size = numpy.prod(ref.shape)
try:
@@ -104,8 +104,8 @@ class TestByteOffset(unittest.TestCase):
1000.0 * (t1 - t0),
1000.0 * (t2 - t1))
bo.log_profile()
- #print(ref)
- #print(res_cl.get())
+ # print(ref)
+ # print(res_cl.get())
self.assertEqual(delta_cy, 0, "Checks fabio works")
self.assertEqual(delta_cl, 0, "Checks opencl works")
@@ -119,7 +119,7 @@ class TestByteOffset(unittest.TestCase):
ref, raw = self._create_test_data(shape=shape, nexcept=0, lam=100)
try:
- bo = byte_offset.ByteOffset(len(raw), size, profile=False)
+ bo = byte_offset.ByteOffset(len(raw), size, profile=True)
except (RuntimeError, pyopencl.RuntimeError) as err:
logger.warning(err)
if sys.platform == "darwin":
@@ -155,6 +155,7 @@ class TestByteOffset(unittest.TestCase):
logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
1000.0 * (t1 - t0),
1000.0 * (t2 - t1))
+ bo.log_profile(stats=True)
def test_encode(self):
"""Test byte offset compression"""
diff --git a/silx/opencl/common.py b/silx/opencl/common.py
index 002c15d..b66b7b7 100644
--- a/silx/opencl/common.py
+++ b/silx/opencl/common.py
@@ -4,7 +4,7 @@
# Project: S I L X project
# https://github.com/silx-kit/silx
#
-# Copyright (C) 2012-2018 European Synchrotron Radiation Facility, Grenoble, France
+# Copyright (C) 2012-2021 European Synchrotron Radiation Facility, Grenoble, France
#
# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
#
@@ -55,7 +55,7 @@ else:
try:
import pyopencl
except ImportError:
- logger.warning("Unable to import pyOpenCl. Please install it from: http://pypi.python.org/pypi/pyopencl")
+ logger.warning("Unable to import pyOpenCl. Please install it from: https://pypi.org/project/pyopencl")
pyopencl = None
else:
try:
diff --git a/silx/opencl/processing.py b/silx/opencl/processing.py
index 470b141..8b81f7f 100644
--- a/silx/opencl/processing.py
+++ b/silx/opencl/processing.py
@@ -34,19 +34,18 @@
Common OpenCL abstract base classe for different processing
"""
-from __future__ import absolute_import, print_function, division
-
__author__ = "Jerome Kieffer"
__contact__ = "Jerome.Kieffer@ESRF.eu"
__license__ = "MIT"
__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "04/12/2020"
+__date__ = "02/03/2021"
__status__ = "stable"
+import sys
import os
import logging
import gc
-from collections import namedtuple
+from collections import namedtuple, OrderedDict
import numpy
import threading
from .common import ocl, pyopencl, release_cl_buffers, query_kernel_info, allocate_texture, check_textures_availability
@@ -342,20 +341,45 @@ class OpenclProcessing(object):
ev = pyopencl.enqueue_copy(*copy_args, **copy_kwargs)
self.profile_add(ev, "Transfer to texture")
- def log_profile(self):
+ def log_profile(self, stats=False):
"""If we are in profiling mode, prints out all timing for every single OpenCL call
+
+ :param stats: if True, prints the statistics on each kernel instead of all execution timings
+ :return: list of lines to print
"""
- t = 0.0
- out = ["", "Profiling info for OpenCL %s" % self.__class__.__name__]
+ total_time = 0.0
+ out = [""]
+ if stats:
+ stats = OrderedDict()
+ out.append(f"OpenCL kernel profiling statistics in milliseconds for: {self.__class__.__name__}")
+ out.append(f"{'Kernel name':>50} (count): min median max mean std")
+ else:
+ stats = None
+ out.append(f"Profiling info for OpenCL: {self.__class__.__name__}")
+
if self.profile:
for e in self.events:
if "__len__" in dir(e) and len(e) >= 2:
- et = 1e-6 * (e[1].profile.end - e[1].profile.start)
- out.append("%50s:\t%.3fms" % (e[0], et))
- t += et
+ name = e[0]
+ pr = e[1].profile
+ t0 = pr.start
+ t1 = pr.end
+ et = 1e-6 * (t1 - t0)
+ total_time += et
+ if stats is None:
+ out.append(f"{name:>50} : {et:.3f}ms")
+ else:
+ if name in stats:
+ stats[name].append(et)
+ else:
+ stats[name] = [et]
+ if stats is not None:
+ for k, v in stats.items():
+ n = numpy.array(v)
+ out.append(f"{k:>50} ({len(v):5}): {n.min():8.3f} {numpy.median(n):8.3f} {n.max():8.3f} {n.mean():8.3f} {n.std():8.3f}")
+ out.append("_" * 80)
+ out.append(f"{'Total OpenCL execution time':>50} : {total_time:.3f}ms")
- out.append("_" * 80)
- out.append("%50s:\t%.3fms" % ("Total execution time", t))
logger.info(os.linesep.join(out))
return out
diff --git a/silx/opencl/statistics.py b/silx/opencl/statistics.py
index bd8e7b7..a96ee33 100644
--- a/silx/opencl/statistics.py
+++ b/silx/opencl/statistics.py
@@ -29,13 +29,10 @@
large data where numpy is not very efficient.
"""
-from __future__ import absolute_import, print_function, with_statement, division
-
-
__author__ = "Jerome Kieffer"
__license__ = "MIT"
-__date__ = "11/01/2019"
-__copyright__ = "2012-2017, ESRF, Grenoble"
+__date__ = "19/05/2021"
+__copyright__ = "2012-2019, ESRF, Grenoble"
__contact__ = "jerome.kieffer@esrf.fr"
import logging
@@ -128,7 +125,7 @@ class Statistics(OpenclProcessing):
self.kernel_files,
"-D NIMAGE=%i" % self.size)
compiler_options = self.get_compiler_options(x87_volatile=True)
- src = concatenate_cl_kernel(("kahan.cl", "statistics.cl"))
+ src = concatenate_cl_kernel(("doubleword.cl", "statistics.cl"))
self.reduction_comp = ReductionKernel(self.ctx,
dtype_out=float8,
neutral=zero8,
@@ -146,6 +143,19 @@ class Statistics(OpenclProcessing):
preamble=src,
options=compiler_options)
+ if "cl_khr_fp64" in self.device.extensions:
+ self.reduction_double = ReductionKernel(self.ctx,
+ dtype_out=float8,
+ neutral=zero8,
+ map_expr="map_statistics(data, i)",
+ reduce_expr="reduce_statistics_double(a,b)",
+ arguments="__global float *data",
+ preamble=src,
+ options=compiler_options)
+ else:
+ logger.info("Device %s does not support double-precision arithmetics, fall-back on compensated one", self.device)
+ self.reduction_double = self.reduction_comp
+
def send_buffer(self, data, dest):
"""
Send a numpy array to the device, including the cast on the device if
@@ -154,7 +164,7 @@ class Statistics(OpenclProcessing):
:param numpy.ndarray data: numpy array with data
:param dest: name of the buffer as registered in the class
"""
-
+ logger.info("send data to %s", dest)
dest_type = numpy.dtype([i.dtype for i in self.buffers if i.name == dest][0])
events = []
if (data.dtype == dest_type) or (data.dtype.itemsize > dest_type.itemsize):
@@ -173,8 +183,8 @@ class Statistics(OpenclProcessing):
self.cl_mem["raw"].data,
self.cl_mem[dest].data)
events += [
- EventDescription("copy H->D %s" % dest, copy_image),
- EventDescription("cast to float", cast_to_float)
+ EventDescription("copy H->D raw", copy_image),
+ EventDescription(f"cast to float {dest}", cast_to_float)
]
if self.profile:
self.events += events
@@ -193,16 +203,24 @@ class Statistics(OpenclProcessing):
size = data.size
assert size <= self.size, "size is OK"
events = []
+ if comp is True:
+ comp = "comp"
+ elif comp is False:
+ comp = "single"
+ else:
+ comp = comp.lower()
with self.sem:
self.send_buffer(data, "converted")
- if comp:
- reduction = self.reduction_comp
- else:
+ if comp in ("single", "fp32", "float32"):
reduction = self.reduction_simple
+ elif comp in ("double", "fp64", "float64"):
+ reduction = self.reduction_double
+ else:
+ reduction = self.reduction_comp
res_d, evt = reduction(self.cl_mem["converted"][:self.size],
queue=self.queue,
return_event=True)
- events.append(EventDescription("statistical reduction %s" % ("comp"if comp else "simple"), evt))
+ events.append(EventDescription(f"statistical reduction {comp}", evt))
if self.profile:
self.events += events
res_h = res_d.get()
diff --git a/silx/opencl/test/__init__.py b/silx/opencl/test/__init__.py
index 2e90e66..928dbaf 100644
--- a/silx/opencl/test/__init__.py
+++ b/silx/opencl/test/__init__.py
@@ -24,7 +24,7 @@
__authors__ = ["J. Kieffer"]
__license__ = "MIT"
-__date__ = "11/01/2019"
+__date__ = "17/05/2021"
import os
import unittest
@@ -37,6 +37,7 @@ from . import test_array_utils
from ..codec import test as test_codec
from . import test_image
from . import test_kahan
+from . import test_doubleword
from . import test_stats
from . import test_convolution
from . import test_sparse
@@ -53,6 +54,7 @@ def suite():
test_suite.addTests(test_codec.suite())
test_suite.addTests(test_image.suite())
test_suite.addTests(test_kahan.suite())
+ test_suite.addTests(test_doubleword.suite())
test_suite.addTests(test_stats.suite())
test_suite.addTests(test_convolution.suite())
test_suite.addTests(test_sparse.suite())
diff --git a/silx/opencl/test/test_doubleword.py b/silx/opencl/test/test_doubleword.py
new file mode 100644
index 0000000..ca947e0
--- /dev/null
+++ b/silx/opencl/test/test_doubleword.py
@@ -0,0 +1,258 @@
+#!/usr/bin/env python
+# coding: utf-8
+#
+# Project: The silx project
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2021-2021 European Synchrotron Radiation Facility, Grenoble, France
+#
+# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+"test suite for OpenCL code"
+
+__author__ = "Jérôme Kieffer"
+__contact__ = "Jerome.Kieffer@ESRF.eu"
+__license__ = "MIT"
+__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "31/05/2021"
+
+import unittest
+import numpy
+import logging
+import platform
+
+logger = logging.getLogger(__name__)
+try:
+ import pyopencl
+except ImportError as error:
+ logger.warning("OpenCL module (pyopencl) is not present, skip tests. %s.", error)
+ pyopencl = None
+
+from .. import ocl
+if ocl is not None:
+ from ..utils import read_cl_file
+ from .. import pyopencl
+ import pyopencl.array
+ from pyopencl.elementwise import ElementwiseKernel
+from ...test.utils import test_options
+
+EPS32 = numpy.finfo("float32").eps
+EPS64 = numpy.finfo("float64").eps
+
+
+class TestDoubleWord(unittest.TestCase):
+ """
+ Test the kernels for compensated math in OpenCL
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ if not test_options.WITH_OPENCL_TEST:
+ raise unittest.SkipTest("User request to skip OpenCL tests")
+ if pyopencl is None or ocl is None:
+ raise unittest.SkipTest("OpenCL module (pyopencl) is not present or no device available")
+
+ cls.ctx = ocl.create_context(devicetype="GPU")
+ cls.queue = pyopencl.CommandQueue(cls.ctx, properties=pyopencl.command_queue_properties.PROFILING_ENABLE)
+
+ # this is running 32 bits OpenCL woth POCL
+ if (platform.machine() in ("i386", "i686", "x86_64") and (tuple.__itemsize__ == 4) and
+ cls.ctx.devices[0].platform.name == 'Portable Computing Language'):
+ cls.args = "-DX87_VOLATILE=volatile"
+ else:
+ cls.args = ""
+ size = 1024
+ cls.a = 1.0 + numpy.random.random(size)
+ cls.b = 1.0 + numpy.random.random(size)
+ cls.ah = cls.a.astype(numpy.float32)
+ cls.bh = cls.b.astype(numpy.float32)
+ cls.al = (cls.a - cls.ah).astype(numpy.float32)
+ cls.bl = (cls.b - cls.bh).astype(numpy.float32)
+ cls.doubleword = read_cl_file("doubleword.cl")
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.queue = None
+ cls.ctx = None
+ cls.a = cls.al = cls.ah = None
+ cls.b = cls.bl = cls.bh = None
+ cls.doubleword = None
+
+ def test_fast_sum2(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *a, float *b, float *res_h, float *res_l",
+ "float2 tmp = fast_fp_plus_fp(a[i], b[i]); res_h[i] = tmp.s0; res_l[i] = tmp.s1",
+ preamble=self.doubleword)
+ a_g = pyopencl.array.to_device(self.queue, self.ah)
+ b_g = pyopencl.array.to_device(self.queue, self.bl)
+ res_l = pyopencl.array.empty_like(a_g)
+ res_h = pyopencl.array.empty_like(a_g)
+ test_kernel(a_g, b_g, res_h, res_l)
+ self.assertEqual(abs(self.ah + self.bl - res_h.get()).max(), 0, "Major matches")
+ self.assertGreater(abs(self.ah.astype(numpy.float64) + self.bl - res_h.get()).max(), 0, "Exact mismatches")
+ self.assertEqual(abs(self.ah.astype(numpy.float64) + self.bl - (res_h.get().astype(numpy.float64) + res_l.get())).max(), 0, "Exact matches")
+
+ def test_sum2(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *a, float *b, float *res_h, float *res_l",
+ "float2 tmp = fp_plus_fp(a[i],b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ a_g = pyopencl.array.to_device(self.queue, self.ah)
+ b_g = pyopencl.array.to_device(self.queue, self.bh)
+ res_l = pyopencl.array.empty_like(a_g)
+ res_h = pyopencl.array.empty_like(a_g)
+ test_kernel(a_g, b_g, res_h, res_l)
+ self.assertEqual(abs(self.ah + self.bh - res_h.get()).max(), 0, "Major matches")
+ self.assertGreater(abs(self.ah.astype(numpy.float64) + self.bh - res_h.get()).max(), 0, "Exact mismatches")
+ self.assertEqual(abs(self.ah.astype(numpy.float64) + self.bh - (res_h.get().astype(numpy.float64) + res_l.get())).max(), 0, "Exact matches")
+
+ def test_prod2(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *a, float *b, float *res_h, float *res_l",
+ "float2 tmp = fp_times_fp(a[i],b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ a_g = pyopencl.array.to_device(self.queue, self.ah)
+ b_g = pyopencl.array.to_device(self.queue, self.bh)
+ res_l = pyopencl.array.empty_like(a_g)
+ res_h = pyopencl.array.empty_like(a_g)
+ test_kernel(a_g, b_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertEqual(abs(self.ah * self.bh - res_m).max(), 0, "Major matches")
+ self.assertGreater(abs(self.ah.astype(numpy.float64) * self.bh - res_m).max(), 0, "Exact mismatches")
+ self.assertEqual(abs(self.ah.astype(numpy.float64) * self.bh - res).max(), 0, "Exact matches")
+
+ def test_dw_plus_fp(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *ah, float *al, float *b, float *res_h, float *res_l",
+ "float2 tmp = dw_plus_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ ah_g = pyopencl.array.to_device(self.queue, self.ah)
+ al_g = pyopencl.array.to_device(self.queue, self.al)
+ b_g = pyopencl.array.to_device(self.queue, self.bh)
+ res_l = pyopencl.array.empty_like(b_g)
+ res_h = pyopencl.array.empty_like(b_g)
+ test_kernel(ah_g, al_g, b_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertLess(abs(self.a + self.bh - res_m).max(), EPS32, "Major matches")
+ self.assertGreater(abs(self.a + self.bh - res_m).max(), EPS64, "Exact mismatches")
+ self.assertLess(abs(self.ah.astype(numpy.float64) + self.al + self.bh - res).max(), 2 * EPS32 ** 2, "Exact matches")
+
+ def test_dw_plus_dw(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
+ "float2 tmp = dw_plus_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ ah_g = pyopencl.array.to_device(self.queue, self.ah)
+ al_g = pyopencl.array.to_device(self.queue, self.al)
+ bh_g = pyopencl.array.to_device(self.queue, self.bh)
+ bl_g = pyopencl.array.to_device(self.queue, self.bl)
+ res_l = pyopencl.array.empty_like(bh_g)
+ res_h = pyopencl.array.empty_like(bh_g)
+ test_kernel(ah_g, al_g, bh_g, bl_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertLess(abs(self.a + self.b - res_m).max(), EPS32, "Major matches")
+ self.assertGreater(abs(self.a + self.b - res_m).max(), EPS64, "Exact mismatches")
+ self.assertLess(abs(self.a + self.b - res).max(), 3 * EPS32 ** 2, "Exact matches")
+
+ def test_dw_times_fp(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *ah, float *al, float *b, float *res_h, float *res_l",
+ "float2 tmp = dw_times_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ ah_g = pyopencl.array.to_device(self.queue, self.ah)
+ al_g = pyopencl.array.to_device(self.queue, self.al)
+ b_g = pyopencl.array.to_device(self.queue, self.bh)
+ res_l = pyopencl.array.empty_like(b_g)
+ res_h = pyopencl.array.empty_like(b_g)
+ test_kernel(ah_g, al_g, b_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertLess(abs(self.a * self.bh - res_m).max(), EPS32, "Major matches")
+ self.assertGreater(abs(self.a * self.bh - res_m).max(), EPS64, "Exact mismatches")
+ self.assertLess(abs(self.a * self.bh - res).max(), 2 * EPS32 ** 2, "Exact matches")
+
+ def test_dw_times_dw(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
+ "float2 tmp = dw_times_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ ah_g = pyopencl.array.to_device(self.queue, self.ah)
+ al_g = pyopencl.array.to_device(self.queue, self.al)
+ bh_g = pyopencl.array.to_device(self.queue, self.bh)
+ bl_g = pyopencl.array.to_device(self.queue, self.bl)
+ res_l = pyopencl.array.empty_like(bh_g)
+ res_h = pyopencl.array.empty_like(bh_g)
+ test_kernel(ah_g, al_g, bh_g, bl_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertLess(abs(self.a * self.b - res_m).max(), EPS32, "Major matches")
+ self.assertGreater(abs(self.a * self.b - res_m).max(), EPS64, "Exact mismatches")
+ self.assertLess(abs(self.a * self.b - res).max(), 5 * EPS32 ** 2, "Exact matches")
+
+ def test_dw_div_fp(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *ah, float *al, float *b, float *res_h, float *res_l",
+ "float2 tmp = dw_div_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ ah_g = pyopencl.array.to_device(self.queue, self.ah)
+ al_g = pyopencl.array.to_device(self.queue, self.al)
+ b_g = pyopencl.array.to_device(self.queue, self.bh)
+ res_l = pyopencl.array.empty_like(b_g)
+ res_h = pyopencl.array.empty_like(b_g)
+ test_kernel(ah_g, al_g, b_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertLess(abs(self.a / self.bh - res_m).max(), EPS32, "Major matches")
+ self.assertGreater(abs(self.a / self.bh - res_m).max(), EPS64, "Exact mismatches")
+ self.assertLess(abs(self.a / self.bh - res).max(), 3 * EPS32 ** 2, "Exact matches")
+
+ def test_dw_div_dw(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
+ "float2 tmp = dw_div_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ ah_g = pyopencl.array.to_device(self.queue, self.ah)
+ al_g = pyopencl.array.to_device(self.queue, self.al)
+ bh_g = pyopencl.array.to_device(self.queue, self.bh)
+ bl_g = pyopencl.array.to_device(self.queue, self.bl)
+ res_l = pyopencl.array.empty_like(bh_g)
+ res_h = pyopencl.array.empty_like(bh_g)
+ test_kernel(ah_g, al_g, bh_g, bl_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertLess(abs(self.a / self.b - res_m).max(), EPS32, "Major matches")
+ self.assertGreater(abs(self.a / self.b - res_m).max(), EPS64, "Exact mismatches")
+ self.assertLess(abs(self.a / self.b - res).max(), 6 * EPS32 ** 2, "Exact matches")
+
+
+def suite():
+ testsuite = unittest.TestSuite()
+ loader = unittest.defaultTestLoader.loadTestsFromTestCase
+ testsuite.addTest(loader(TestDoubleWord))
+ return testsuite
+
+
+if __name__ == '__main__':
+ runner = unittest.TextTestRunner()
+ runner.run(suite())
diff --git a/silx/opencl/test/test_kahan.py b/silx/opencl/test/test_kahan.py
index 167640c..6ea599b 100644
--- a/silx/opencl/test/test_kahan.py
+++ b/silx/opencl/test/test_kahan.py
@@ -1,10 +1,10 @@
#!/usr/bin/env python
# coding: utf-8
#
-# Project: Azimuthal integration
-# https://github.com/silx-kit/pyFAI
+# Project: OpenCL numerical library
+# https://github.com/silx-kit/silx
#
-# Copyright (C) 2015-2019 European Synchrotron Radiation Facility, Grenoble, France
+# Copyright (C) 2015-2021 European Synchrotron Radiation Facility, Grenoble, France
#
# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
#
@@ -28,13 +28,11 @@
"test suite for OpenCL code"
-from __future__ import absolute_import, division, print_function
-
__author__ = "Jérôme Kieffer"
__contact__ = "Jerome.Kieffer@ESRF.eu"
__license__ = "MIT"
__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "01/08/2019"
+__date__ = "17/05/2021"
import unittest
diff --git a/silx/opencl/test/test_stats.py b/silx/opencl/test/test_stats.py
index b5127c8..8baf05e 100644
--- a/silx/opencl/test/test_stats.py
+++ b/silx/opencl/test/test_stats.py
@@ -28,14 +28,11 @@
"""
Simple test of an addition
"""
-
-from __future__ import division, print_function
-
__authors__ = ["Henri Payno, Jérôme Kieffer"]
__contact__ = "jerome.kieffer@esrf.eu"
__license__ = "MIT"
__copyright__ = "2013 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "13/12/2018"
+__date__ = "19/05/2021"
import logging
import time
@@ -58,11 +55,13 @@ class TestStatistics(unittest.TestCase):
def setUpClass(cls):
cls.size = 1 << 20 # 1 million elements
cls.data = numpy.random.randint(0, 65000, cls.size).astype("uint16")
- t0 = time.time()
- cls.ref = StatResults(cls.data.min(), cls.data.max(), cls.data.size,
- cls.data.sum(), cls.data.mean(), cls.data.std() ** 2,
- cls.data.std())
- t1 = time.time()
+ fdata = cls.data.astype("float64")
+ t0 = time.perf_counter()
+ std = fdata.std()
+ cls.ref = StatResults(fdata.min(), fdata.max(), float(fdata.size),
+ fdata.sum(), fdata.mean(), std ** 2,
+ std)
+ t1 = time.perf_counter()
cls.ref_time = t1 - t0
@classmethod
@@ -89,19 +88,22 @@ class TestStatistics(unittest.TestCase):
s = Statistics(template=self.data, platformid=pid, deviceid=did)
except Exception as err:
failed_init = True
- res = StatResults(0,0,0,0,0,0,0)
+ res = StatResults(0, 0, 0, 0, 0, 0, 0)
+ print(err)
else:
failed_init = False
- t0 = time.time()
- res = s(self.data)
- t1 = time.time()
- logger.warning("failed_init %s", failed_init)
- if failed_init or not self.validate(res):
- logger.error("Failed on platform %s device %s", platform, device)
- logger.error("Reference results: %s", self.ref)
- logger.error("Faulty results: %s", res)
- self.assertTrue(False, "Stat calculation failed on %s %s" % (platform, device))
- logger.info("Runtime on %s/%s : %.3fms x%.1f", platform, device, 1000 * (t1 - t0), self.ref_time / (t1 - t0))
+ for comp in ("single", "double", "comp"):
+ t0 = time.perf_counter()
+ res = s(self.data, comp=comp)
+ t1 = time.perf_counter()
+ logger.info("Runtime on %s/%s : %.3fms x%.1f", platform, device, 1000 * (t1 - t0), self.ref_time / (t1 - t0))
+
+ if failed_init or not self.validate(res):
+ logger.error("failed_init %s; Computation modes %s", failed_init, comp)
+ logger.error("Failed on platform %s device %s", platform, device)
+ logger.error("Reference results: %s", self.ref)
+ logger.error("Faulty results: %s", res)
+ self.assertTrue(False, f"Stat calculation failed on {platform},{device} in mode {comp}")
def suite():
diff --git a/silx/resources/opencl/doubleword.cl b/silx/resources/opencl/doubleword.cl
new file mode 100644
index 0000000..a0ebfda
--- /dev/null
+++ b/silx/resources/opencl/doubleword.cl
@@ -0,0 +1,115 @@
+/*
+ * OpenCL library for double word floating point calculation using compensated arithmetics
+ *
+ * The theoritical basis can be found in Valentina Popescu's PhD thesis:
+ * Towards fast and certified multi-precision libraries
+ * Reference LYSEN036
+ * http://www.theses.fr/2017LYSEN036
+ * All page number and equation number are refering to this document.
+ *
+ * The precision of the calculation (bounds) is provided in ULP (smallest possible mantissa)
+ * and come from the table 2.2 (page 68 of the thesis).
+ * The number of equivalent FLOP is taken from the table 2.3 (page 69 the thesis).
+ * Note that FLOP are not all equal: a division is much more expensive than an addition.
+ */
+
+//This library can be expanded to double-double by redefining fp, fp2 and one to double, double2 and 1.0.
+#ifdef DOUBLEDOUBLE
+#define fp double
+#define fp2 double2
+#define one 1.0
+#else
+#define fp float
+#define fp2 float2
+#define one 1.0f
+#endif
+
+/* Nota: i386 computer use x87 registers which are larger than the 32bits precision
+ * which can invalidate the error compensation mechanism.
+ *
+ * We use the trick to declare some variable "volatile" to enforce the actual
+ * precision reduction of those variables.
+*/
+
+#ifndef X87_VOLATILE
+# define X87_VOLATILE
+#endif
+
+//Algorithm 1, p23, theorem 1.1.12. Requires e_x > e_y, valid if |x| > |y|
+inline fp2 fast_fp_plus_fp(fp x, fp y){
+ X87_VOLATILE fp s = x + y;
+ X87_VOLATILE fp z = s - x;
+ fp e = y - z;
+ return (fp2)(s, e);
+}
+
+//Algorithm 2, p24, same as fast_fp_plus_fp without the condition on e_x and e_y
+inline fp2 fp_plus_fp(fp x, fp y){
+ X87_VOLATILE fp s = x + y;
+ X87_VOLATILE fp xp = s - y;
+ X87_VOLATILE fp yp = s - xp;
+ X87_VOLATILE fp dx = x - xp;
+ X87_VOLATILE fp dy = y - yp;
+ return (fp2)(s, dx+dy);
+}
+
+//Algorithm 3, p24: multiplication with a FMA
+inline fp2 fp_times_fp(fp x, fp y){
+ fp p = x * y;
+ fp e = fma(x, y, -p);
+ return (fp2)(p, e);
+}
+
+//Algorithm 7, p38: Addition of a FP to a DW. 10flop bounds:2u²+5u³
+inline fp2 dw_plus_fp(fp2 x, fp y){
+ fp2 s = fp_plus_fp(x.s0, y);
+ X87_VOLATILE fp v = x.s1 + s.s1;
+ return fast_fp_plus_fp(s.s0, v);
+}
+
+//Algorithm 9, p40: addition of two DW: 20flop bounds:3u²+13u³
+inline fp2 dw_plus_dw(fp2 x, fp2 y){
+ fp2 s = fp_plus_fp(x.s0, y.s0);
+ fp2 t = fp_plus_fp(x.s1, y.s1);
+ fp2 v = fast_fp_plus_fp(s.s0, s.s1 + t.s0);
+ return fast_fp_plus_fp(v.s0, t.s1 + v.s1);
+}
+
+//Algorithm 12, p49: Multiplication FP*DW: 6flops bounds:2u²
+inline fp2 dw_times_fp(fp2 x, fp y){
+ fp2 c = fp_times_fp(x.s0, y);
+ return fast_fp_plus_fp(c.s0, fma(x.s1, y, c.s1));
+}
+
+//Algorithm 14, p52: Multiplication DW*DW, 8 flops bounds:6u²
+inline fp2 dw_times_dw(fp2 x, fp2 y){
+ fp2 c = fp_times_fp(x.s0, y.s0);
+ X87_VOLATILE fp l = fma(x.s1, y.s0, x.s0 * y.s1);
+ return fast_fp_plus_fp(c.s0, c.s1 + l);
+}
+
+//Algorithm 17, p55: Division DW / FP, 10flops bounds: 3.5u²
+inline fp2 dw_div_fp(fp2 x, fp y){
+ X87_VOLATILE fp th = x.s0 / y;
+ fp2 pi = fp_times_fp(th, y);
+ fp2 d = x - pi;
+ X87_VOLATILE fp delta = d.s0 + d.s1;
+ X87_VOLATILE fp tl = delta/y;
+ return fast_fp_plus_fp(th, tl);
+}
+
+//Derived from algorithm 20, p64: Inversion 1/ DW, 22 flops
+inline fp2 inv_dw(fp2 y){
+ X87_VOLATILE fp th = one/y.s0;
+ X87_VOLATILE fp rh = fma(-y.s0, th, one);
+ X87_VOLATILE fp rl = -y.s1 * th;
+ fp2 e = fast_fp_plus_fp(rh, rl);
+ fp2 delta = dw_times_fp(e, th);
+ return dw_plus_fp(delta, th);
+}
+
+//Algorithm 20, p64: Division DW / DW, 30 flops: bounds:9.8u²
+inline fp2 dw_div_dw(fp2 x, fp2 y){
+ return dw_times_dw(x, inv_dw(y));
+}
+
diff --git a/silx/resources/opencl/statistics.cl b/silx/resources/opencl/statistics.cl
index c7d98db..47d925b 100644
--- a/silx/resources/opencl/statistics.cl
+++ b/silx/resources/opencl/statistics.cl
@@ -3,11 +3,11 @@
*
*
*
- * Copyright (C) 2012-2017 European Synchrotron Radiation Facility
+ * Copyright (C) 2012-2021 European Synchrotron Radiation Facility
* Grenoble, France
*
* Principal authors: J. Kieffer (kieffer@esrf.fr)
- * Last revision: 13/12/2018
+ * Last revision: 17/05/2021
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -33,10 +33,29 @@
*
* \brief OpenCL kernels for min, max, mean and std calculation
*
- * Constant to be provided at build time:
+ * This module provides two functions to perform the `map` and the `reduce`
+ * to be used with pyopencl reduction to calculate in a single pass the minimum,
+ * maximum, sum, count, mean and standart deviation for an array.
+ *
+ * So beside the reduction mechanisme from pyopencl, this algorithm implementes equations from
+ * https://dbs.ifi.uni-heidelberg.de/files/Team/eschubert/publications/SSDBM18-covariance-authorcopy.pdf
+ *
+ * let A and B be 2 disjoint partition of all elements
+ *
+ * Omega_A = sum_{i \in A}(omaga_i) The sum of all weights
+ * V_A is the weighted sum of the signal over the partition
+ * VV_A is the weighted sum of deviation squarred
+ *
+ * With this the mean is V / Omega and the variance equals VV / omega.
+ *
+ * Redction operator performs:
+ * Omega_{AB} = Omega_A + Omega_B
+ * V_{AB} = V_A + V_B
+ * VV{AB} = VV_A + VV_B + (Omega_A*V_B-Omega_B*V_A)² / (Omega_A * Omega_B * Omega_{AB})
+ *
+ * To avoid any numerical degradation, the doubleword library is used to perform all floating point operations.
*
*/
-
#include "for_eclipse.h"
/* \brief read a value at given position and initialize the float8 for the reduction
@@ -44,12 +63,12 @@
* The float8 returned contains:
* s0: minimum value
* s1: maximum value
- * s2: count number of valid pixels
- * s3: count (error associated to)
- * s4: sum of valid pixels
- * s5: sum (error associated to)
- * s6: variance*count
- * s7: variance*count (error associated to)
+ * s2: Omega_h count number of valid pixels
+ * s3: Omega_l error associated to the count
+ * s4: V_h sum of signal
+ * s5: V_l error associated to the sum of signal
+ * s6: VVh variance*count
+ * s7: VVl error associated to variance*count
*
*/
static inline float8 map_statistics(global float* data, int position)
@@ -60,11 +79,12 @@ static inline float8 map_statistics(global float* data, int position)
if (isfinite(value))
{
result = (float8)(value, value, 1.0f, 0.0f, value, 0.0f, 0.0f, 0.0f);
- // min max cnt cnt_err sum sum_err M M_err
+ // min max cnt_h cnt_l sum_h sum_l M2_h M2_l
}
else
{
result = (float8)(FLT_MAX, -FLT_MAX, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+ // min max cnt_h cnt_l sum_h sum_l M2_h M2_l
}
return result;
}
@@ -89,7 +109,8 @@ static inline float8 map_statistics(global float* data, int position)
static inline float8 reduce_statistics(float8 a, float8 b)
{
float2 sum_a, sum_b, M_a, M_b, count_a, count_b;
-
+ float2 count, sum;
+ float2 M, delta, delta2, omega3;
//test on count
if (a.s2 == 0.0f)
{
@@ -113,39 +134,22 @@ static inline float8 reduce_statistics(float8 a, float8 b)
M_b = (float2)(b.s6, b.s7);
}
// count = count_a + count_b
- float2 count = compensated_sum(count_a, count_b);
+ count = dw_plus_dw(count_a, count_b);
// sum = sum_a + sum_b
- float2 sum = compensated_sum(sum_a, sum_b);
-
- //delta = avg_b - avg_a
- //delta^2 = avg_b^2 + avg_a^2 - 2*avg_b*avg_a
- //coount_a*count_b*delta^2 = count_a/count_b * sum_b^2 + count_b/count_a*sum_a^2 - 2*sum_a*sum_b
-
- //float2 sum2_a = compensated_mul(sum_a, sum_a);
- //float2 sum2_b = compensated_mul(sum_b, sum_b);
- //float2 ca_over_cb = compensated_mul(count_a, compensated_inv(count_b));
- //float2 cb_over_ca = compensated_mul(count_b, compensated_inv(count_a));
-
- //float2 delta2cbca = compensated_sum(compensated_sum(
- // compensated_mul(ca_over_cb, sum2_b),
- // compensated_mul(cb_over_ca, sum2_a)),
- // -2.0f * compensated_mul(sum_a, sum_b));
-//////////////
-// float2 delta = compensated_sum(
-// compensated_mul(sum_b, compensated_inv(count_b)),
-// -1*(compensated_mul(sum_a, compensated_inv(count_a))));
- float2 delta = compensated_sum(compensated_div(sum_b, count_b),
- -1*compensated_div(sum_a, count_a));
-
- float2 delta2cbca = compensated_mul(compensated_mul(delta, delta),
- compensated_mul(count_a, count_b));
- float2 M2 = compensated_sum(compensated_sum(M_a, M_b),
- compensated_mul(delta2cbca, compensated_inv(count)));
- //M2 = M_a + M_b + delta ** 2 * count_a * count_b / (count_a + count_b)
+ sum = dw_plus_dw(sum_a, sum_b);
+
+ // M2 = M_a + M_b + (Omega_A*V_B-Omega_B*V_A)² / (Omega_A * Omega_B * Omega_{AB})
+ M = dw_plus_dw(M_a, M_b);
+ delta = dw_plus_dw(dw_times_dw(count_b, sum_a),
+ -dw_times_dw(count_a, sum_b));
+ delta2 = dw_times_dw(delta, delta);
+ omega3 = dw_times_dw(count, dw_times_dw(count_a, count_b));
+ M = dw_plus_dw(M, dw_div_dw(delta2, omega3));
+
float8 result = (float8)(min(a.s0, b.s0), max(a.s1, b.s1),
count.s0, count.s1,
sum.s0, sum.s1,
- M2.s0, M2.s1);
+ M.s0, M.s1);
return result;
}
@@ -157,12 +161,12 @@ static inline float8 reduce_statistics(float8 a, float8 b)
* The float8 used here contain contains:
* s0: minimum value
* s1: maximum value
- * s2: count number of valid pixels
- * s3: count (error associated to)
- * s4: sum of valid pixels
- * s5: sum (error associated to)
- * s6: M=variance*(count-1)
- * s7: M=variance*(count-1) (error associated to)
+ * s2: count number of valid pixels (major)
+ * s3: count number of valid pixels (minor)
+ * s4: sum of valid pixels (major)
+ * s5: sum of valid pixels (minor)
+ * s6: variance*count (major)
+ * s7: variance*count (minor)
*
*/
@@ -194,10 +198,9 @@ static inline float8 reduce_statistics_simple(float8 a, float8 b)
}
float count = count_a + count_b;
float sum = sum_a + sum_b;
- float delta = sum_a/count_a - sum_b/count_b;
- float delta2cbca = count_b * count_a * delta * delta;
- float M2 = M_a + M_b + delta2cbca/count;
- //M2 = M_a + M_b + delta ** 2 * count_a * count_b / (count_a + count_b)
+ float delta = sum_a*count_b - sum_b*count_a;
+ float delta2 = delta * delta;
+ float M2 = M_a + M_b + delta2/(count*count_a*count_b);
float8 result = (float8)(min(a.s0, b.s0), max(a.s1, b.s1),
count, 0.0f,
sum, 0.0f,
@@ -206,3 +209,75 @@ static inline float8 reduce_statistics_simple(float8 a, float8 b)
}
+#ifdef cl_khr_fp64
+#pragma OPENCL EXTENSION cl_khr_fp64 : enable
+
+
+// unpack a double in two floats such as the sum of the two is the double number
+static inline float2 unpack_double(double inp){
+ float major = (float) inp;
+ float minor = (float) (inp - major);
+ return (float2)(major, minor);
+}
+
+// pack two floats into a double
+static inline double pack_double(float major, float minor){
+ return (double)major + (double)minor;
+}
+
+/* \brief reduction function associated to the statistics using double precision arithmetics
+ *
+ * this is described in:
+ * https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+ *
+ * The float8 used here contain contains:
+ * s0: minimum value
+ * s1: maximum value
+ * s2: count number of valid pixels (major)
+ * s3: count number of valid pixels (minor)
+ * s4: sum of valid pixels (major)
+ * s5: sum of valid pixels (minor)
+ * s6: variance*count (major)
+ * s7: variance*count (minor)
+ *
+ */
+
+static inline float8 reduce_statistics_double(float8 a, float8 b)
+{
+ double sum_a, sum_b, M_a, M_b, count_a, count_b;
+
+ //test on count
+ if (a.s2 == 0.0)
+ {
+ return b;
+ }
+ else
+ {
+ count_a = pack_double(a.s2, a.s3);
+ sum_a = pack_double(a.s4,a.s5);
+ M_a = pack_double(a.s6, a.s7);
+ }
+ //test on count
+ if (b.s2 == 0.0)
+ {
+ return a;
+ }
+ else
+ {
+ count_b = pack_double(b.s2, b.s3);
+ sum_b = pack_double(b.s4, b.s5);
+ M_b = pack_double(b.s6, b.s7);
+ }
+ double count = count_a + count_b;
+ double sum = sum_a + sum_b;
+ double delta = sum_a*count_b - sum_b*count_a;
+ double delta2 = delta * delta;
+ double M2 = M_a + M_b + delta2/(count*count_a*count_b);
+ float8 result = (float8)((float2)(min(a.s0, b.s0), max(a.s1, b.s1)),
+ unpack_double(count),
+ unpack_double( sum),
+ unpack_double( M2));
+ return result;
+}
+
+#endif \ No newline at end of file
diff --git a/silx/test/__init__.py b/silx/test/__init__.py
index 0991469..2063ab5 100644
--- a/silx/test/__init__.py
+++ b/silx/test/__init__.py
@@ -89,10 +89,13 @@ def suite():
return test_suite
-def run_tests():
- """Run test complete test_suite"""
+def run_tests(*args, **kwargs):
+ """Run test complete test_suite
+
+ Provided arguments are passed to :class:`unittest.TextTestRunner`.
+ """
test_options.configure()
- runner = unittest.TextTestRunner()
+ runner = unittest.TextTestRunner(*args, **kwargs)
if not runner.run(suite()).wasSuccessful():
print("Test suite failed")
return 1
diff --git a/silx/utils/retry.py b/silx/utils/retry.py
new file mode 100644
index 0000000..adc43bc
--- /dev/null
+++ b/silx/utils/retry.py
@@ -0,0 +1,264 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides utility methods for retrying methods until they
+no longer fail.
+"""
+
+__authors__ = ["W. de Nolf"]
+__license__ = "MIT"
+__date__ = "05/02/2020"
+
+
+import time
+from functools import wraps
+from contextlib import contextmanager
+import multiprocessing
+from queue import Empty
+
+
+RETRY_PERIOD = 0.01
+
+
+class RetryTimeoutError(TimeoutError):
+ pass
+
+
+class RetryError(RuntimeError):
+ pass
+
+
+def _default_retry_on_error(e):
+ """
+ :param BaseException e:
+ :returns bool:
+ """
+ return isinstance(e, RetryError)
+
+
+@contextmanager
+def _handle_exception(options):
+ try:
+ yield
+ except BaseException as e:
+ retry_on_error = options.get("retry_on_error")
+ if retry_on_error is not None and retry_on_error(e):
+ options["exception"] = e
+ else:
+ raise
+
+
+def _retry_loop(retry_timeout=None, retry_period=None, retry_on_error=None):
+ """Iterator which is endless or ends with an RetryTimeoutError.
+ It yields a dictionary which can be used to influence the loop.
+
+ :param num retry_timeout:
+ :param num retry_period: sleep before retry
+ :param callable or None retry_on_error: checks whether an exception is
+ eligible for retry
+ """
+ has_timeout = retry_timeout is not None
+ options = {"exception": None, "retry_on_error": retry_on_error}
+ if has_timeout:
+ t0 = time.time()
+ while True:
+ yield options
+ if retry_period is not None:
+ time.sleep(retry_period)
+ if has_timeout and (time.time() - t0) > retry_timeout:
+ raise RetryTimeoutError from options.get("exception")
+
+
+def retry(
+ retry_timeout=None, retry_period=None, retry_on_error=_default_retry_on_error
+):
+ """Decorator for a method that needs to be executed until it not longer
+ fails or until `retry_on_error` returns False.
+
+ The decorator arguments can be overriden by using them when calling the
+ decorated method.
+
+ :param num retry_timeout:
+ :param num retry_period: sleep before retry
+ :param callable or None retry_on_error: checks whether an exception is
+ eligible for retry
+ """
+
+ if retry_period is None:
+ retry_period = RETRY_PERIOD
+
+ def decorator(method):
+ @wraps(method)
+ def wrapper(*args, **kw):
+ _retry_timeout = kw.pop("retry_timeout", retry_timeout)
+ _retry_period = kw.pop("retry_period", retry_period)
+ _retry_on_error = kw.pop("retry_on_error", retry_on_error)
+ for options in _retry_loop(
+ retry_timeout=_retry_timeout,
+ retry_period=_retry_period,
+ retry_on_error=_retry_on_error,
+ ):
+ with _handle_exception(options):
+ return method(*args, **kw)
+
+ return wrapper
+
+ return decorator
+
+
+def retry_contextmanager(
+ retry_timeout=None, retry_period=None, retry_on_error=_default_retry_on_error
+):
+ """Decorator to make a context manager from a method that needs to be
+ entered until it no longer fails or until `retry_on_error` returns False.
+
+ The decorator arguments can be overriden by using them when calling the
+ decorated method.
+
+ :param num retry_timeout:
+ :param num retry_period: sleep before retry
+ :param callable or None retry_on_error: checks whether an exception is
+ eligible for retry
+ """
+
+ if retry_period is None:
+ retry_period = RETRY_PERIOD
+
+ def decorator(method):
+ @wraps(method)
+ def wrapper(*args, **kw):
+ _retry_timeout = kw.pop("retry_timeout", retry_timeout)
+ _retry_period = kw.pop("retry_period", retry_period)
+ _retry_on_error = kw.pop("retry_on_error", retry_on_error)
+ for options in _retry_loop(
+ retry_timeout=_retry_timeout,
+ retry_period=_retry_period,
+ retry_on_error=_retry_on_error,
+ ):
+ with _handle_exception(options):
+ gen = method(*args, **kw)
+ result = next(gen)
+ options["retry_on_error"] = None
+ yield result
+ try:
+ next(gen)
+ except StopIteration:
+ return
+ else:
+ raise RuntimeError(str(method) + " should only yield once")
+
+ return contextmanager(wrapper)
+
+ return decorator
+
+
+def _subprocess_main(queue, method, retry_on_error, *args, **kw):
+ try:
+ result = method(*args, **kw)
+ except BaseException as e:
+ if retry_on_error(e):
+ # As the traceback gets lost, make sure the top-level
+ # exception is RetryError
+ e = RetryError(str(e))
+ queue.put(e)
+ else:
+ queue.put(result)
+
+
+def retry_in_subprocess(
+ retry_timeout=None, retry_period=None, retry_on_error=_default_retry_on_error
+):
+ """Same as `retry` but it also retries segmentation faults.
+
+ As subprocesses are spawned, you cannot use this decorator with the "@" syntax
+ because the decorated method needs to be an attribute of a module:
+
+ .. code-block:: python
+
+ def _method(*args, **kw):
+ ...
+
+ method = retry_in_subprocess()(_method)
+
+ :param num retry_timeout:
+ :param num retry_period: sleep before retry
+ :param callable or None retry_on_error: checks whether an exception is
+ eligible for retry
+ """
+
+ if retry_period is None:
+ retry_period = RETRY_PERIOD
+
+ def decorator(method):
+ @wraps(method)
+ def wrapper(*args, **kw):
+ _retry_timeout = kw.pop("retry_timeout", retry_timeout)
+ _retry_period = kw.pop("retry_period", retry_period)
+ _retry_on_error = kw.pop("retry_on_error", retry_on_error)
+
+ ctx = multiprocessing.get_context("spawn")
+
+ def start_subprocess():
+ queue = ctx.Queue(maxsize=1)
+ p = ctx.Process(
+ target=_subprocess_main,
+ args=(queue, method, retry_on_error) + args,
+ kwargs=kw,
+ )
+ p.start()
+ return p, queue
+
+ def stop_subprocess(p):
+ try:
+ p.kill()
+ except AttributeError:
+ p.terminate()
+ p.join()
+
+ p, queue = start_subprocess()
+ try:
+ for options in _retry_loop(
+ retry_timeout=_retry_timeout, retry_on_error=_retry_on_error
+ ):
+ with _handle_exception(options):
+ if not p.is_alive():
+ p, queue = start_subprocess()
+ try:
+ result = queue.get(block=True, timeout=_retry_period)
+ except Empty:
+ pass
+ except ValueError:
+ pass
+ else:
+ if isinstance(result, BaseException):
+ stop_subprocess(p)
+ raise result
+ else:
+ return result
+ finally:
+ stop_subprocess(p)
+
+ return wrapper
+
+ return decorator
diff --git a/silx/utils/test/__init__.py b/silx/utils/test/__init__.py
index 252bc05..b35feee 100755
--- a/silx/utils/test/__init__.py
+++ b/silx/utils/test/__init__.py
@@ -39,6 +39,7 @@ from . import test_number
from . import test_external_resources
from . import test_enum
from . import test_testutils
+from . import test_retry
def suite():
@@ -54,4 +55,5 @@ def suite():
test_suite.addTest(test_external_resources.suite())
test_suite.addTest(test_enum.suite())
test_suite.addTest(test_testutils.suite())
+ test_suite.addTest(test_retry.suite())
return test_suite
diff --git a/silx/utils/test/test_retry.py b/silx/utils/test/test_retry.py
new file mode 100644
index 0000000..d223f44
--- /dev/null
+++ b/silx/utils/test/test_retry.py
@@ -0,0 +1,179 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for retry utilities"""
+
+__authors__ = ["W. de Nolf"]
+__license__ = "MIT"
+__date__ = "05/02/2020"
+
+
+import unittest
+import os
+import sys
+import time
+import tempfile
+
+from .. import retry
+
+
+def _cause_segfault():
+ import ctypes
+
+ i = ctypes.c_char(b"a")
+ j = ctypes.pointer(i)
+ c = 0
+ while True:
+ j[c] = b"a"
+ c += 1
+
+
+def _submain(filename, kwcheck=None, ncausefailure=0, faildelay=0):
+ assert filename
+ assert kwcheck
+ sys.stderr = open(os.devnull, "w")
+
+ with open(filename, mode="r") as f:
+ failcounter = int(f.readline().strip())
+
+ if failcounter < ncausefailure:
+ time.sleep(faildelay)
+ failcounter += 1
+ with open(filename, mode="w") as f:
+ f.write(str(failcounter))
+ if failcounter % 2:
+ raise retry.RetryError
+ else:
+ _cause_segfault()
+ return True
+
+
+_wsubmain = retry.retry_in_subprocess()(_submain)
+
+
+class TestRetry(unittest.TestCase):
+ def setUp(self):
+ self.test_dir = tempfile.mkdtemp()
+ self.ctr_file = os.path.join(self.test_dir, "failcounter.txt")
+
+ def tearDown(self):
+ if os.path.exists(self.ctr_file):
+ os.unlink(self.ctr_file)
+ os.rmdir(self.test_dir)
+
+ def test_retry(self):
+ ncausefailure = 3
+ faildelay = 0.1
+ sufficient_timeout = ncausefailure * (faildelay + 10)
+ insufficient_timeout = ncausefailure * faildelay * 0.5
+
+ @retry.retry()
+ def method(check, kwcheck=None):
+ assert check
+ assert kwcheck
+ nonlocal failcounter
+ if failcounter < ncausefailure:
+ time.sleep(faildelay)
+ failcounter += 1
+ raise retry.RetryError
+ return True
+
+ failcounter = 0
+ kw = {
+ "kwcheck": True,
+ "retry_timeout": sufficient_timeout,
+ }
+ self.assertTrue(method(True, **kw))
+
+ failcounter = 0
+ kw = {
+ "kwcheck": True,
+ "retry_timeout": insufficient_timeout,
+ }
+ with self.assertRaises(retry.RetryTimeoutError):
+ method(True, **kw)
+
+ def test_retry_contextmanager(self):
+ ncausefailure = 3
+ faildelay = 0.1
+ sufficient_timeout = ncausefailure * (faildelay + 10)
+ insufficient_timeout = ncausefailure * faildelay * 0.5
+
+ @retry.retry_contextmanager()
+ def context(check, kwcheck=None):
+ assert check
+ assert kwcheck
+ nonlocal failcounter
+ if failcounter < ncausefailure:
+ time.sleep(faildelay)
+ failcounter += 1
+ raise retry.RetryError
+ yield True
+
+ failcounter = 0
+ kw = {"kwcheck": True, "retry_timeout": sufficient_timeout}
+ with context(True, **kw) as result:
+ self.assertTrue(result)
+
+ failcounter = 0
+ kw = {"kwcheck": True, "retry_timeout": insufficient_timeout}
+ with self.assertRaises(retry.RetryTimeoutError):
+ with context(True, **kw) as result:
+ pass
+
+ def test_retry_in_subprocess(self):
+ ncausefailure = 3
+ faildelay = 0.1
+ sufficient_timeout = ncausefailure * (faildelay + 10)
+ insufficient_timeout = ncausefailure * faildelay * 0.5
+
+ kw = {
+ "ncausefailure": ncausefailure,
+ "faildelay": faildelay,
+ "kwcheck": True,
+ "retry_timeout": sufficient_timeout,
+ }
+ with open(self.ctr_file, mode="w") as f:
+ f.write("0")
+ self.assertTrue(_wsubmain(self.ctr_file, **kw))
+
+ kw = {
+ "ncausefailure": ncausefailure,
+ "faildelay": faildelay,
+ "kwcheck": True,
+ "retry_timeout": insufficient_timeout,
+ }
+ with open(self.ctr_file, mode="w") as f:
+ f.write("0")
+ with self.assertRaises(retry.RetryTimeoutError):
+ _wsubmain(self.ctr_file, **kw)
+
+
+def suite():
+ test_suite = unittest.TestSuite()
+ test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestRetry))
+ return test_suite
+
+
+if __name__ == "__main__":
+ unittest.main(defaultTest="suite")
diff --git a/silx/utils/test/test_testutils.py b/silx/utils/test/test_testutils.py
index c29a703..c72a3d8 100755
--- a/silx/utils/test/test_testutils.py
+++ b/silx/utils/test/test_testutils.py
@@ -84,6 +84,15 @@ class TestTestLogging(unittest.TestCase):
logger.error("aaa")
self.assertIsNotNone(listener)
+ def testErrorMessage(self):
+ logger = logging.getLogger(__name__ + "testCanBreak")
+ listener = testutils.TestLogging(logger, error=1, warning=2)
+ with self.assertRaisesRegex(RuntimeError, "aaabbb"):
+ with listener:
+ logger.error("aaa")
+ logger.warning("aaabbb")
+ logger.error("aaa")
+
def suite():
loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
diff --git a/silx/utils/testutils.py b/silx/utils/testutils.py
index 1252269..434beee 100755
--- a/silx/utils/testutils.py
+++ b/silx/utils/testutils.py
@@ -102,6 +102,17 @@ def parameterize(test_case_class, *args, **kwargs):
return suite
+class LoggingRuntimeError(RuntimeError):
+ """Raised when the `TestLogging` fails"""
+
+ def __init__(self, msg, records):
+ super(LoggingRuntimeError, self).__init__(msg)
+ self.records = records
+
+ def __str__(self):
+ return super(LoggingRuntimeError, self).__str__() + " -> " + str(self.records)
+
+
class TestLogging(logging.Handler):
"""Context checking the number of logging messages from a specified Logger.
@@ -220,8 +231,8 @@ class TestLogging(logging.Handler):
expected_count = expected_count_by_level[level]
message += "%d %s (got %d)" % (expected_count, logging.getLevelName(level), count)
- raise RuntimeError(
- 'Expected %s' % message)
+ raise LoggingRuntimeError(
+ 'Expected %s' % message, records=list(self.records))
def emit(self, record):
"""Override :meth:`logging.Handler.emit`"""