From b3bea947efa55d2c0f198b6c6795b3177be27f45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Picca=20Fr=C3=A9d=C3=A9ric-Emmanuel?= Date: Wed, 6 Jan 2021 14:10:12 +0100 Subject: New upstream version 0.14.0+dfsg --- CHANGELOG.rst | 114 ++ MANIFEST.in | 1 + PKG-INFO | 6 +- README.rst | 4 +- copyright | 4 + doc/source/Tutorials/fit.rst | 6 +- doc/source/applications/view.rst | 3 +- .../modules/gui/data/img/ArrayTableWidget.png | Bin 42383 -> 41977 bytes doc/source/modules/gui/data/img/DataViewer.png | Bin 38670 -> 37500 bytes doc/source/modules/gui/icons.rst | 9 + .../modules/gui/plot/img/BasicGridStatsWidget.png | Bin 12081 -> 20668 bytes doc/source/modules/gui/plot/img/LimitsToolBar.png | Bin 21697 -> 21679 bytes doc/source/modules/gui/plot/img/ROIStatsWidget.png | Bin 0 -> 9659 bytes doc/source/modules/gui/plot/img/logColorbar.png | Bin 11461 -> 12390 bytes doc/source/modules/gui/plot/index.rst | 1 + doc/source/modules/gui/plot/plotwidget.rst | 3 + doc/source/modules/gui/plot/roistatswidget.rst | 24 + doc/source/modules/gui/plot3d/img/SceneWidget.png | Bin 65009 -> 73947 bytes .../modules/gui/widgets/img/FrameBrowser.png | Bin 3186 -> 4538 bytes .../modules/gui/widgets/img/PeriodicCombo.png | Bin 2607 -> 3124 bytes .../modules/gui/widgets/img/PeriodicList.png | Bin 25731 -> 38615 bytes .../modules/gui/widgets/img/PeriodicTable.png | Bin 33564 -> 61338 bytes doc/source/modules/gui/widgets/img/TableWidget.png | Bin 3789 -> 3624 bytes .../gui/widgets/img/ThreadPoolPushButton.png | Bin 2151 -> 2566 bytes .../modules/gui/widgets/img/WaitingPushButton.png | Bin 1068 -> 1187 bytes doc/source/modules/io/specfile.rst | 3 +- doc/source/sample_code/img/plotROIStats.png | Bin 0 -> 167957 bytes doc/source/sample_code/index.rst | 8 + examples/plotInteractiveImageROI.py | 34 +- examples/plotROIStats.py | 341 +++++ examples/plotStats.py | 5 +- package/debian10/control | 2 +- pyproject.toml | 7 + run_tests.py | 16 +- silx.egg-info/PKG-INFO | 6 +- silx.egg-info/SOURCES.txt | 21 +- silx/app/test/test_convert.py | 4 +- silx/app/view/Viewer.py | 2 + silx/app/view/main.py | 6 +- silx/gui/_glutils/FramebufferTexture.py | 3 +- silx/gui/_glutils/OpenGLWidget.py | 14 + silx/gui/_glutils/Texture.py | 319 +++-- silx/gui/_glutils/utils.py | 30 +- silx/gui/colors.py | 117 +- silx/gui/data/DataViews.py | 2 +- silx/gui/data/Hdf5TableView.py | 68 +- silx/gui/data/NXdataWidgets.py | 1 + silx/gui/data/TextFormatter.py | 8 +- silx/gui/data/test/test_dataviewer.py | 8 +- silx/gui/data/test/test_textformatter.py | 28 +- silx/gui/fit/BackgroundWidget.py | 4 +- silx/gui/fit/FitWidget.py | 2 +- silx/gui/hdf5/Hdf5Item.py | 24 +- silx/gui/hdf5/test/test_hdf5.py | 162 ++- silx/gui/plot/ColorBar.py | 5 +- silx/gui/plot/ComplexImageView.py | 2 +- silx/gui/plot/CurvesROIWidget.py | 6 +- silx/gui/plot/ImageStack.py | 25 +- silx/gui/plot/ImageView.py | 12 +- silx/gui/plot/MaskToolsWidget.py | 30 +- silx/gui/plot/PlotInteraction.py | 19 + silx/gui/plot/PlotWidget.py | 186 ++- silx/gui/plot/PlotWindow.py | 107 +- silx/gui/plot/ROIStatsWidget.py | 780 +++++++++++ silx/gui/plot/ScatterMaskToolsWidget.py | 24 +- silx/gui/plot/StackView.py | 66 +- silx/gui/plot/StatsWidget.py | 32 +- silx/gui/plot/_BaseMaskToolsWidget.py | 14 +- silx/gui/plot/_utils/dtime_ticklayout.py | 16 +- silx/gui/plot/actions/control.py | 79 +- silx/gui/plot/actions/io.py | 71 +- silx/gui/plot/backends/BackendBase.py | 25 +- silx/gui/plot/backends/BackendMatplotlib.py | 149 +- silx/gui/plot/backends/BackendOpenGL.py | 426 +++--- silx/gui/plot/backends/glutils/GLPlotCurve.py | 86 +- silx/gui/plot/backends/glutils/GLPlotFrame.py | 159 ++- silx/gui/plot/backends/glutils/GLPlotImage.py | 103 +- silx/gui/plot/backends/glutils/GLPlotItem.py | 94 ++ silx/gui/plot/backends/glutils/GLPlotTriangles.py | 14 +- silx/gui/plot/backends/glutils/GLText.py | 60 +- silx/gui/plot/backends/glutils/GLTexture.py | 5 +- silx/gui/plot/backends/glutils/__init__.py | 3 +- silx/gui/plot/items/__init__.py | 3 +- silx/gui/plot/items/_arc_roi.py | 873 ++++++++++++ silx/gui/plot/items/_pick.py | 2 +- silx/gui/plot/items/_roi_base.py | 835 ++++++++++++ silx/gui/plot/items/complex.py | 15 +- silx/gui/plot/items/core.py | 189 ++- silx/gui/plot/items/curve.py | 23 - silx/gui/plot/items/histogram.py | 35 +- silx/gui/plot/items/image.py | 79 +- silx/gui/plot/items/roi.py | 1438 +------------------- silx/gui/plot/items/scatter.py | 19 +- silx/gui/plot/items/shape.py | 35 +- silx/gui/plot/matplotlib/__init__.py | 50 +- silx/gui/plot/stats/stats.py | 497 +++++-- silx/gui/plot/stats/statshandler.py | 12 +- silx/gui/plot/test/__init__.py | 2 + silx/gui/plot/test/testComplexImageView.py | 8 +- silx/gui/plot/test/testCurvesROIWidget.py | 10 +- silx/gui/plot/test/testItem.py | 90 +- silx/gui/plot/test/testMaskToolsWidget.py | 15 +- silx/gui/plot/test/testPlotInteraction.py | 6 +- silx/gui/plot/test/testPlotWidget.py | 237 ++-- silx/gui/plot/test/testPlotWindow.py | 21 +- silx/gui/plot/test/testRoiStatsWidget.py | 290 ++++ silx/gui/plot/test/testScatterMaskToolsWidget.py | 16 +- silx/gui/plot/test/testStackView.py | 15 +- silx/gui/plot/test/testStats.py | 273 +++- silx/gui/plot/tools/profile/manager.py | 31 +- silx/gui/plot/tools/profile/rois.py | 14 +- silx/gui/plot/tools/roi.py | 239 +++- silx/gui/plot/tools/test/testROI.py | 67 +- silx/gui/plot3d/ScalarFieldView.py | 6 +- silx/gui/plot3d/items/_pick.py | 4 +- silx/gui/plot3d/items/core.py | 54 +- silx/gui/plot3d/items/mixins.py | 1 + silx/gui/plot3d/items/volume.py | 2 +- silx/gui/plot3d/scene/cutplane.py | 4 +- silx/gui/plot3d/scene/function.py | 75 +- silx/gui/plot3d/scene/primitives.py | 10 +- silx/gui/plot3d/scene/text.py | 3 +- silx/gui/plot3d/scene/transform.py | 65 +- silx/gui/plot3d/scene/utils.py | 4 +- silx/gui/plot3d/test/testStatsWidget.py | 3 + silx/gui/test/test_colors.py | 51 +- silx/gui/utils/glutils.py | 7 + silx/gui/utils/matplotlib.py | 71 + silx/gui/utils/signal.py | 141 ++ silx/gui/utils/testutils.py | 2 - silx/gui/widgets/ElidedLabel.py | 4 +- silx/gui/widgets/test/__init__.py | 4 +- silx/gui/widgets/test/test_legendiconwidget.py | 74 + silx/image/marchingsquares/_mergeimpl.pyx | 4 +- silx/image/tomography.py | 2 + silx/io/commonh5.py | 22 +- silx/io/dictdump.py | 421 ++++-- silx/io/fabioh5.py | 10 +- silx/io/nxdata/parse.py | 4 +- silx/io/setup.py | 2 +- silx/io/specfile/src/locale_management.c | 5 +- silx/io/test/test_dictdump.py | 257 +++- silx/io/test/test_spectoh5.py | 3 +- silx/io/test/test_url.py | 10 + silx/io/test/test_utils.py | 244 +++- silx/io/url.py | 21 +- silx/io/utils.py | 331 ++++- silx/math/colormap.pyx | 24 +- silx/math/fft/test/test_fft.py | 8 +- silx/math/fit/bgtheories.py | 8 +- silx/math/fit/fitmanager.py | 16 +- silx/math/fit/fittheories.py | 34 +- silx/math/fit/functions.pyx | 4 +- silx/math/fit/leastsq.py | 30 +- silx/math/fit/test/test_fit.py | 8 +- silx/math/fit/test/test_fitmanager.py | 12 +- silx/opencl/backprojection.py | 33 +- silx/opencl/common.py | 90 +- silx/opencl/convolution.py | 11 +- silx/opencl/processing.py | 54 +- silx/opencl/projection.py | 33 +- silx/opencl/test/test_addition.py | 28 +- silx/opencl/test/test_backprojection.py | 3 +- silx/opencl/test/test_convolution.py | 99 +- silx/resources/gui/icons/add.png | Bin 0 -> 470 bytes silx/resources/gui/icons/add.svg | 2 + silx/resources/gui/icons/backend-opengl.png | Bin 0 -> 1582 bytes silx/resources/gui/icons/backend-opengl.svg | 18 + silx/resources/gui/icons/rm.png | Bin 0 -> 348 bytes silx/resources/gui/icons/rm.svg | 2 + silx/resources/opencl/backproj.cl | 301 +--- silx/resources/opencl/proj.cl | 4 +- silx/sx/_plot.py | 4 +- silx/utils/_have_openmp.pxd | 49 + silx/utils/_have_openmp.pxi | 49 - version.py | 27 +- 176 files changed, 8701 insertions(+), 3628 deletions(-) create mode 100644 doc/source/modules/gui/plot/img/ROIStatsWidget.png create mode 100644 doc/source/modules/gui/plot/roistatswidget.rst create mode 100644 doc/source/sample_code/img/plotROIStats.png create mode 100644 examples/plotROIStats.py create mode 100644 pyproject.toml create mode 100644 silx/gui/plot/ROIStatsWidget.py create mode 100644 silx/gui/plot/backends/glutils/GLPlotItem.py create mode 100644 silx/gui/plot/items/_arc_roi.py create mode 100644 silx/gui/plot/items/_roi_base.py create mode 100644 silx/gui/plot/test/testRoiStatsWidget.py create mode 100644 silx/gui/utils/matplotlib.py create mode 100644 silx/gui/utils/signal.py create mode 100644 silx/gui/widgets/test/test_legendiconwidget.py create mode 100644 silx/resources/gui/icons/add.png create mode 100644 silx/resources/gui/icons/add.svg create mode 100644 silx/resources/gui/icons/backend-opengl.png create mode 100644 silx/resources/gui/icons/backend-opengl.svg create mode 100644 silx/resources/gui/icons/rm.png create mode 100644 silx/resources/gui/icons/rm.svg create mode 100644 silx/utils/_have_openmp.pxd delete mode 100644 silx/utils/_have_openmp.pxi diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e194827..8370a32 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,120 @@ Change Log ========== +0.14.0: 2020/12/11 +------------------ + +This is the first version of `silx` supporting `h5py` >= v3.0. + +This is the last version of `silx` officially supporting Python 3.5. + +* `silx.gui`: + + * Added support for HDF5 external data (virtual and raw) (PR #3222) + * Added lazy update handling of OpenGL textures (PR #3205) + * Deprecated `silx.gui.plot.matplotlib` module (use `silx.gui.utils.matplotlib` instead) (PR #3158) + * Improved memory allocation by using already defined `fontMetrics` instread of creating a new one (PR #3239) + * Make `TextFormatter` compatible with `h5py`>=3 (PR #3253) + * Fixed `matplotlib` 3.3.0rc1 deprecation warnings (PR #3145) + + * `silx.gui.colors.Colormap`: + + * Added `Colormap.get|setNaNColor` to change color used for NaN, fix different NaN displays for matplotlib/openGL backends (PR #3143) + * Refactored PlotWidget OpenGL backend to enable extensions (PR #3147) + * Fixed use of `QThreadPool.tryTake` to be Qt5.7 compliant (PR #3250) + + * `silx.gui.plot`: + + * Added the feature to compute statistics inside a specific region of interest (PR #3056) + * Added an action to switch on/off OpenGL rendering on a plot (PR #3261) + * Added test for ROI interaction mode (PR #3283) + * Added saving of error bars when saving a plot (PR #3199) + * Added `ImageStack.clear` (PR #3167) + * Improved image profile tool to support `PlotWidget` item extension (PR #3150) + * Improved `Stackview`: replaced `setColormap` `autoscale` argument by `scaleColormapRangeToStack` method (PR #3279) + * Updated `3 stddev` autoscale algorithm, clamp it with the minmax data in order to improve the contrast (PR #3284) + * Updated ROI module: splitted into 3 modules base/common/arc_roi (PR #3283) + * Fixed `ColormapDialog` custom range input (PR #3153) + * Fixed issue when changing ROI mode while a ROI is being created (PR #3186) + * Fixed `RegionOfInterest` refresh when highlighted (PR #3197) + * Fixed arc roi shape: make sure start and end points are part of the shape (PR #3257) + * Fixed issue in `Colormap` `3 stdev` autoscale mode and avoided warnings (PR #3295) + + * Major improvements of `PlotWidget`: + + * Added `get|setAxesMargins` methods to control margin ratios around plot area (PR #3196) + * Added `PlotWidget.[get|set]Backend` enabling switching backend (PR #3255) + * Added multi interaction mode for ROIs (can be switched with a single click on an handle, or the context menu) (PR #3260) + * Added polar interaction mode for arc ROI (PR #3260) + * Added `PlotWidget.sigDefaultContextMenu` to allow to feed the default context menu (PR #3260) + * Added context menu to the selected ROI to remove it (PR #3260) + * Added pan interaction to ROI authoring (`select-draw`) interaction mode (PR #3291) + * Added support of right axis label with OpenGL backend (PR #3293) + * Added item visible bounds feature to PlotWidget items (PR #3223) + * Added a `DataItem` base class for items having a "data extent" in the plot (PR #3212) + * Added support for float16 texture in OpenGL backend (PR #3194) + * Improved support of high-DPI screen in OpenGL backend (PR #3203) + * Updated: Use points rather than pixels for marker size and line width with OpenGL backend (PR #3203) + * Updated: Expose `PlotWidget` colors as Qt properties (PR #3269) + * Fixed time serie axis for range < 2.5 microseconds (PR #3195) + * Fixed initial size of OpenGL backend (PR #3209) + * Fixed `PlotWidget` image items displayed below the grid by default (PR #3235) + * Fixed OpenGL backend image display with sqrt colormap normalization (PR #3248) + * Fixed support of shapes with multiple polygons in the OpenGL backend (PR #3259) + * Fixes duplicated callback on ROIs (there was one for each ROI managed created on the plot) (PR #3260) + * Fixed RegionOfInterest `contains` methods (PR #3336) + + * `silx.gui.colors.plot3d`: + + * Improved scene rendering (PR #3149) + * Fixed handling of transparency of cut plane (PR #3204) + +* `silx.image`: + + * Fixed slow `image.tomography.get_next_power()` (PR #3168) + +* `silx.io`: + + * Added support for HDF5 link preservation in `dictdump` (PR #3224) + * Added support for numpy arrays of `numbers` (PR #3251) + * Make `h5todict` resilient to issues in the HDF5 file (PR #3162) + +* `silx.math`: + + * Improved colormap performances for small datasets (PR #3282) + +* `silx.opencl`: + + * Added textures availability check (PR #3273) + * Added a warning when there is an issue in the Ocl destruction (PR #3280) + * Fixed Sift test on modern GPU (PR #3262) + +* Miscellaneous: + + * Added HDF5 strings: handle `h5py` 2.x and 3.x (PR #3240) + * Fixed `cython` 3 compatibility and deprecation warning (PR #3164, #3189) + + +0.13.2: 2020/09/15 +------------------ + +Minor release: + +* silx view application: Prevent collapsing browsing panel, Added `-f` command line option (PR #3176) + +* `silx.gui`: + + * `silx.gui.data`: Fixed `DataViews.titleForSelection` method (PR #3171). + * `silx.gui.plot.items`: Added `DATA_BOUNDS` visualization parameter for `Scatter` item histogram bounds (PR #3180) + * `silx.gui.plot.PlotWidget`: Fixed support of curves with infinite data (PR #3175) + * `silx.gui.utils.glutils`: Fixed `isOpenGLAvailable` function (PR #3184) + +* Documentation: + + * Update silx view command line options documentation (PR #3173) + * Update version number and changelog (PR #3190) + + 0.13.1: 2020/07/22 ------------------ diff --git a/MANIFEST.in b/MANIFEST.in index abdc1f9..da024c2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -9,6 +9,7 @@ include stdeb.cfg include build-deb.sh include requirements.txt include requirements-dev.txt +include pyproject.toml recursive-include silx *.pyx *.pxd *.pxi recursive-include silx *.h *.c *.hpp *.cpp recursive-include doc/source *.py *.rst *.png *.ico *.ipynb diff --git a/PKG-INFO b/PKG-INFO index 74f97d5..6bf8a6f 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: silx -Version: 0.13.1 +Version: 0.14.0 Summary: Software library for X-ray data analysis Home-page: http://www.silx.org/ Author: data analysis unit @@ -108,8 +108,8 @@ Description: *silx* releases can be cited via their DOI on Zenodo: |zenodo DOI| - .. |Travis Status| image:: https://travis-ci.org/silx-kit/silx.svg?branch=master - :target: https://travis-ci.org/silx-kit/silx?branch=master + .. |Travis Status| image:: https://travis-ci.com/silx-kit/silx.svg?branch=master + :target: https://travis-ci.com/silx-kit/silx .. |Appveyor Status| image:: https://ci.appveyor.com/api/projects/status/qgox9ei0wxwfagrb/branch/master?svg=true :target: https://ci.appveyor.com/project/ESRF/silx?branch=master .. |zenodo DOI| image:: https://zenodo.org/badge/DOI/10.5281/zenodo.591709.svg diff --git a/README.rst b/README.rst index 6b30551..20842ae 100644 --- a/README.rst +++ b/README.rst @@ -100,8 +100,8 @@ Citation *silx* releases can be cited via their DOI on Zenodo: |zenodo DOI| -.. |Travis Status| image:: https://travis-ci.org/silx-kit/silx.svg?branch=master - :target: https://travis-ci.org/silx-kit/silx?branch=master +.. |Travis Status| image:: https://travis-ci.com/silx-kit/silx.svg?branch=master + :target: https://travis-ci.com/silx-kit/silx .. |Appveyor Status| image:: https://ci.appveyor.com/api/projects/status/qgox9ei0wxwfagrb/branch/master?svg=true :target: https://ci.appveyor.com/project/ESRF/silx?branch=master .. |zenodo DOI| image:: https://zenodo.org/badge/DOI/10.5281/zenodo.591709.svg diff --git a/copyright b/copyright index 174de4b..25937ce 100644 --- a/copyright +++ b/copyright @@ -24,6 +24,10 @@ Files: silx/third_party/modest_image.py Copyright: 2013 Chris Beaumont License: MIT +Files: silx/gui/utils/signal.py +Copyright: 2012 University of North Carolina at Chapel Hill, Luke Campagnola +License: MIT + License: MIT Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), diff --git a/doc/source/Tutorials/fit.rst b/doc/source/Tutorials/fit.rst index b1b28e5..d9671f4 100644 --- a/doc/source/Tutorials/fit.rst +++ b/doc/source/Tutorials/fit.rst @@ -234,7 +234,7 @@ the previous tutorial (See `Weighted fit`_) from silx.math.fit.fitmanager import FitManager # Create synthetic data with a sum of gaussian functions - x = numpy.arange(1000).astype(numpy.float) + x = numpy.arange(1000).astype(numpy.float64) y = 2.4 * x**4 - 10. * x**3 + 15.2 * x**2 - 24.6 * x + 150. # define our fit function: a generic polynomial of degree 4 @@ -304,7 +304,7 @@ data, generated using another *silx* module: :mod:`silx.math.fit.functions`. from silx.math.fit.fitmanager import FitManager # Create synthetic data with a sum of gaussian functions - x = numpy.arange(1000).astype(numpy.float) + x = numpy.arange(1000).astype(numpy.float64) # height, center x, fwhm p = [1000, 100., 250, # 1st peak @@ -526,7 +526,7 @@ Simple usage from silx.gui.fit import FitWidget from silx.math.fit.functions import sum_gauss - x = numpy.arange(2000).astype(numpy.float) + x = numpy.arange(2000).astype(numpy.float64) constant_bg = 3.14 # gaussian parameters: height, position, fwhm diff --git a/doc/source/applications/view.rst b/doc/source/applications/view.rst index 694f95d..d4145c2 100644 --- a/doc/source/applications/view.rst +++ b/doc/source/applications/view.rst @@ -48,7 +48,8 @@ Options -h, --help Show this help message and exit --debug Set logging system in debug mode --use-opengl-plot Use OpenGL for plots (instead of matplotlib) - + -f, --fresh Start the application using new fresh user preferences + --hdf5-file-locking Start the application with HDF5 file locking enabled (it is disabled by default) Examples of usage ----------------- diff --git a/doc/source/modules/gui/data/img/ArrayTableWidget.png b/doc/source/modules/gui/data/img/ArrayTableWidget.png index c879427..e7bb2a9 100644 Binary files a/doc/source/modules/gui/data/img/ArrayTableWidget.png and b/doc/source/modules/gui/data/img/ArrayTableWidget.png differ diff --git a/doc/source/modules/gui/data/img/DataViewer.png b/doc/source/modules/gui/data/img/DataViewer.png index a1fabb9..7980de9 100644 Binary files a/doc/source/modules/gui/data/img/DataViewer.png and b/doc/source/modules/gui/data/img/DataViewer.png differ diff --git a/doc/source/modules/gui/icons.rst b/doc/source/modules/gui/icons.rst index 8a939ea..67235c2 100644 --- a/doc/source/modules/gui/icons.rst +++ b/doc/source/modules/gui/icons.rst @@ -53,10 +53,14 @@ Available icons - add-shape-unknown * - |add-shape-vertical| - add-shape-vertical + * - |add| + - add * - |arrow-keys| - arrow-keys * - |axis| - axis + * - |backend-opengl| + - backend-opengl * - |camera| - camera * - |clipboard| @@ -303,6 +307,8 @@ Available icons - profile2D * - |remove| - remove + * - |rm| + - rm * - |rotate-3d| - rotate-3d * - |rudder| @@ -413,8 +419,10 @@ Available icons .. |add-shape-rectangle| image:: ../../../../silx/resources/gui/icons/add-shape-rectangle.png .. |add-shape-unknown| image:: ../../../../silx/resources/gui/icons/add-shape-unknown.png .. |add-shape-vertical| image:: ../../../../silx/resources/gui/icons/add-shape-vertical.png +.. |add| image:: ../../../../silx/resources/gui/icons/add.png .. |arrow-keys| image:: ../../../../silx/resources/gui/icons/arrow-keys.png .. |axis| image:: ../../../../silx/resources/gui/icons/axis.png +.. |backend-opengl| image:: ../../../../silx/resources/gui/icons/backend-opengl.png .. |camera| image:: ../../../../silx/resources/gui/icons/camera.png .. |clipboard| image:: ../../../../silx/resources/gui/icons/clipboard.png .. |close| image:: ../../../../silx/resources/gui/icons/close.png @@ -538,6 +546,7 @@ Available icons .. |profile1D| image:: ../../../../silx/resources/gui/icons/profile1D.png .. |profile2D| image:: ../../../../silx/resources/gui/icons/profile2D.png .. |remove| image:: ../../../../silx/resources/gui/icons/remove.png +.. |rm| image:: ../../../../silx/resources/gui/icons/rm.png .. |rotate-3d| image:: ../../../../silx/resources/gui/icons/rotate-3d.png .. |rudder| image:: ../../../../silx/resources/gui/icons/rudder.png .. |selected| image:: ../../../../silx/resources/gui/icons/selected.png diff --git a/doc/source/modules/gui/plot/img/BasicGridStatsWidget.png b/doc/source/modules/gui/plot/img/BasicGridStatsWidget.png index bc675f0..261909a 100644 Binary files a/doc/source/modules/gui/plot/img/BasicGridStatsWidget.png and b/doc/source/modules/gui/plot/img/BasicGridStatsWidget.png differ diff --git a/doc/source/modules/gui/plot/img/LimitsToolBar.png b/doc/source/modules/gui/plot/img/LimitsToolBar.png index ede66c8..b360fe0 100644 Binary files a/doc/source/modules/gui/plot/img/LimitsToolBar.png and b/doc/source/modules/gui/plot/img/LimitsToolBar.png differ diff --git a/doc/source/modules/gui/plot/img/ROIStatsWidget.png b/doc/source/modules/gui/plot/img/ROIStatsWidget.png new file mode 100644 index 0000000..7a634fe Binary files /dev/null and b/doc/source/modules/gui/plot/img/ROIStatsWidget.png differ diff --git a/doc/source/modules/gui/plot/img/logColorbar.png b/doc/source/modules/gui/plot/img/logColorbar.png index cde1ad9..49282e7 100644 Binary files a/doc/source/modules/gui/plot/img/logColorbar.png and b/doc/source/modules/gui/plot/img/logColorbar.png differ diff --git a/doc/source/modules/gui/plot/index.rst b/doc/source/modules/gui/plot/index.rst index b6c2000..7f60ba4 100644 --- a/doc/source/modules/gui/plot/index.rst +++ b/doc/source/modules/gui/plot/index.rst @@ -59,6 +59,7 @@ Additionnal plot tool widgets: roi.rst printpreviewtoolbutton.rst statswidget.rst + roistatswidget.rst stats/index.rst Utilities diff --git a/doc/source/modules/gui/plot/plotwidget.rst b/doc/source/modules/gui/plot/plotwidget.rst index 9978479..d16c4ab 100644 --- a/doc/source/modules/gui/plot/plotwidget.rst +++ b/doc/source/modules/gui/plot/plotwidget.rst @@ -82,7 +82,10 @@ The following methods handle plot limits, aspect ratio, grid and axes display: .. automethod:: PlotWidget.setKeepDataAspectRatio .. automethod:: PlotWidget.getGraphGrid .. automethod:: PlotWidget.setGraphGrid +.. automethod:: PlotWidget.isAxesDisplayed .. automethod:: PlotWidget.setAxesDisplayed +.. automethod:: PlotWidget.getAxesMargins +.. automethod:: PlotWidget.setAxesMargins Reset zoom .......... diff --git a/doc/source/modules/gui/plot/roistatswidget.rst b/doc/source/modules/gui/plot/roistatswidget.rst new file mode 100644 index 0000000..d9563b5 --- /dev/null +++ b/doc/source/modules/gui/plot/roistatswidget.rst @@ -0,0 +1,24 @@ + +.. currentmodule:: silx.gui.plot.ROIStatsWidget + +:mod:`ROIStatsWidget`: Display a set of statistics for couples (plot items, roi) +================================================================================ + +An example of the usage is given in examples/plotRoiStats.py + +.. automodule:: silx.gui.plot.ROIStatsWidget + + +:class:`ROIStatsWidget` class +----------------------------- + +.. autoclass:: ROIStatsWidget + :show-inheritance: + :members: + +:class:`ROIStatsItemHelper` class +--------------------------------- + +.. autoclass:: ROIStatsItemHelper + :show-inheritance: + :members: diff --git a/doc/source/modules/gui/plot3d/img/SceneWidget.png b/doc/source/modules/gui/plot3d/img/SceneWidget.png index dbe7791..4ddc0a8 100644 Binary files a/doc/source/modules/gui/plot3d/img/SceneWidget.png and b/doc/source/modules/gui/plot3d/img/SceneWidget.png differ diff --git a/doc/source/modules/gui/widgets/img/FrameBrowser.png b/doc/source/modules/gui/widgets/img/FrameBrowser.png index 17b355a..1d4ebcf 100644 Binary files a/doc/source/modules/gui/widgets/img/FrameBrowser.png and b/doc/source/modules/gui/widgets/img/FrameBrowser.png differ diff --git a/doc/source/modules/gui/widgets/img/PeriodicCombo.png b/doc/source/modules/gui/widgets/img/PeriodicCombo.png index 644e502..7534805 100644 Binary files a/doc/source/modules/gui/widgets/img/PeriodicCombo.png and b/doc/source/modules/gui/widgets/img/PeriodicCombo.png differ diff --git a/doc/source/modules/gui/widgets/img/PeriodicList.png b/doc/source/modules/gui/widgets/img/PeriodicList.png index 5ec741f..74ce7d6 100644 Binary files a/doc/source/modules/gui/widgets/img/PeriodicList.png and b/doc/source/modules/gui/widgets/img/PeriodicList.png differ diff --git a/doc/source/modules/gui/widgets/img/PeriodicTable.png b/doc/source/modules/gui/widgets/img/PeriodicTable.png index a521bd7..bada39a 100644 Binary files a/doc/source/modules/gui/widgets/img/PeriodicTable.png and b/doc/source/modules/gui/widgets/img/PeriodicTable.png differ diff --git a/doc/source/modules/gui/widgets/img/TableWidget.png b/doc/source/modules/gui/widgets/img/TableWidget.png index de78687..a614ae7 100644 Binary files a/doc/source/modules/gui/widgets/img/TableWidget.png and b/doc/source/modules/gui/widgets/img/TableWidget.png differ diff --git a/doc/source/modules/gui/widgets/img/ThreadPoolPushButton.png b/doc/source/modules/gui/widgets/img/ThreadPoolPushButton.png index 5bdebee..eb55b14 100644 Binary files a/doc/source/modules/gui/widgets/img/ThreadPoolPushButton.png and b/doc/source/modules/gui/widgets/img/ThreadPoolPushButton.png differ diff --git a/doc/source/modules/gui/widgets/img/WaitingPushButton.png b/doc/source/modules/gui/widgets/img/WaitingPushButton.png index 9bab0fa..97bd14a 100644 Binary files a/doc/source/modules/gui/widgets/img/WaitingPushButton.png and b/doc/source/modules/gui/widgets/img/WaitingPushButton.png differ diff --git a/doc/source/modules/io/specfile.rst b/doc/source/modules/io/specfile.rst index a937ca8..9b26e31 100644 --- a/doc/source/modules/io/specfile.rst +++ b/doc/source/modules/io/specfile.rst @@ -5,8 +5,7 @@ ---------------------------------- .. automodule:: silx.io.specfile - :members: - :undoc-members: + .. autoclass:: silx.io.specfile.SpecFile :members: diff --git a/doc/source/sample_code/img/plotROIStats.png b/doc/source/sample_code/img/plotROIStats.png new file mode 100644 index 0000000..18446aa Binary files /dev/null and b/doc/source/sample_code/img/plotROIStats.png differ diff --git a/doc/source/sample_code/index.rst b/doc/source/sample_code/index.rst index 15bd4c7..0aade4c 100644 --- a/doc/source/sample_code/index.rst +++ b/doc/source/sample_code/index.rst @@ -283,6 +283,14 @@ Sample code that adds specific tools or functions to :class:`~silx.gui.plot.Plot .. note:: for now the possible types manged by the Stats are ('curve', 'image', 'scatter' and 'histogram') + * - :download:`plotROIStats.py <../../../examples/plotROIStats.py>` + - .. image:: img/plotROIStats.png + :width: 150px + - This script is a simple example of how to display statistics on a specific + region of interest. + + An example on how to define your own statistic is given in the 'plotStats.py' + script. * - :download:`plotProfile.py <../../../examples/plotProfile.py>` - .. image:: img/plotProfile.png :width: 150px diff --git a/examples/plotInteractiveImageROI.py b/examples/plotInteractiveImageROI.py index c10bbf3..7254b7e 100644 --- a/examples/plotInteractiveImageROI.py +++ b/examples/plotInteractiveImageROI.py @@ -38,8 +38,10 @@ from silx.gui import qt from silx.gui.plot import Plot2D from silx.gui.plot.tools.roi import RegionOfInterestManager from silx.gui.plot.tools.roi import RegionOfInterestTableWidget +from silx.gui.plot.tools.roi import RoiModeSelectorAction from silx.gui.plot.items.roi import RectangleROI from silx.gui.plot.items import LineMixIn, SymbolMixIn +from silx.gui.plot.actions import control as control_actions def dummy_image(): @@ -54,16 +56,16 @@ def dummy_image(): app = qt.QApplication([]) # Start QApplication -backend = "matplotlib" -if "--opengl" in sys.argv: - backend = "opengl" - # Create the plot widget and add an image -plot = Plot2D(backend=backend) +plot = Plot2D() plot.getDefaultColormap().setName('viridis') plot.setKeepDataAspectRatio(True) plot.addImage(dummy_image()) +toolbar = qt.QToolBar() +toolbar.addAction(control_actions.OpenGLAction(parent=toolbar, plot=plot)) +plot.addToolBar(toolbar) + # Create the object controlling the ROIs and set it up roiManager = RegionOfInterestManager(plot) roiManager.setColor('pink') # Set the color of ROI @@ -105,11 +107,33 @@ for roiClass in roiManager.getSupportedRoiClasses(): action = roiManager.getInteractionModeAction(roiClass) roiToolbar.addAction(action) +class AutoHideToolBar(qt.QToolBar): + """A toolbar which hide itself if no actions are visible""" + + def actionEvent(self, event): + if event.type() == qt.QEvent.ActionChanged: + self._updateVisibility() + return qt.QToolBar.actionEvent(self, event) + + def _updateVisibility(self): + visible = False + for action in self.actions(): + if action.isVisible(): + visible = True + break + self.setVisible(visible) + +roiToolbarEdit = AutoHideToolBar() +modeSelectorAction = RoiModeSelectorAction() +modeSelectorAction.setRoiManager(roiManager) +roiToolbarEdit.addAction(modeSelectorAction) + # Add the region of interest table and the buttons to a dock widget widget = qt.QWidget() layout = qt.QVBoxLayout() widget.setLayout(layout) layout.addWidget(roiToolbar) +layout.addWidget(roiToolbarEdit) layout.addWidget(roiTable) diff --git a/examples/plotROIStats.py b/examples/plotROIStats.py new file mode 100644 index 0000000..3caff7e --- /dev/null +++ b/examples/plotROIStats.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This script is a simple example of how to display statistics on a specific +region of interest. + +An example on how to define your own statistic is given in the 'plotStats.py' +script. +""" + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "23/07/2019" + +from silx.gui import qt +from silx.gui.plot.tools.roi import RegionOfInterestManager +from silx.gui.plot.tools.roi import RegionOfInterestTableWidget +from silx.gui.plot.items.roi import RectangleROI, PolygonROI, ArcROI +from silx.gui.plot import Plot2D +from silx.gui.plot.CurvesROIWidget import ROI +from silx.gui.plot.ROIStatsWidget import ROIStatsWidget +from silx.gui.plot.StatsWidget import UpdateModeWidget +import sys +import argparse +import functools +import numpy +import threading +from silx.gui.utils import concurrent +import random +import time + + +class UpdateThread(threading.Thread): + """Thread updating the image of a :class:`~silx.gui.plot.Plot2D` + + :param plot2d: The Plot2D to update.""" + + def __init__(self, plot2d): + self.plot2d = plot2d + self.running = False + super(UpdateThread, self).__init__() + + def start(self): + """Start the update thread""" + self.running = True + super(UpdateThread, self).start() + + def run(self): + """Method implementing thread loop that updates the plot""" + while self.running: + time.sleep(1) + # Run plot update asynchronously + concurrent.submitToQtMainThread( + self.plot2d.addImage, + numpy.random.random(10000).reshape(100, 100), + resetzoom=False, + legend=random.choice(('img1', 'img2')) + ) + + def stop(self): + """Stop the update thread""" + self.running = False + self.join(2) + + +class _RoiStatsWidget(qt.QMainWindow): + """ + Window used to associate ROIStatsWidget and UpdateModeWidget + """ + def __init__(self, parent=None, plot=None, mode=None): + assert plot is not None + qt.QMainWindow.__init__(self, parent) + self._roiStatsWindow = ROIStatsWidget(plot=plot) + self.setCentralWidget(self._roiStatsWindow) + + # update mode docker + self._updateModeControl = UpdateModeWidget(parent=self) + self._docker = qt.QDockWidget(parent=self) + self._docker.setWidget(self._updateModeControl) + self.addDockWidget(qt.Qt.TopDockWidgetArea, + self._docker) + self.setWindowFlags(qt.Qt.Widget) + + # connect signal / slot + self._updateModeControl.sigUpdateModeChanged.connect( + self._roiStatsWindow._setUpdateMode) + callback = functools.partial(self._roiStatsWindow._updateAllStats, + is_request=True) + self._updateModeControl.sigUpdateRequested.connect(callback) + + # expose API + self.registerROI = self._roiStatsWindow.registerROI + self.setStats = self._roiStatsWindow.setStats + self.addItem = self._roiStatsWindow.addItem + self.removeItem = self._roiStatsWindow.removeItem + self.setUpdateMode = self._updateModeControl.setUpdateMode + + # setup + self._updateModeControl.setUpdateMode('auto') + + +class _RoiStatsDisplayExWindow(qt.QMainWindow): + """ + Simple window to group the different statistics actors + """ + def __init__(self, parent=None, mode=None): + qt.QMainWindow.__init__(self, parent) + self.plot = Plot2D() + self.setCentralWidget(self.plot) + + # 1D roi management + self._curveRoiWidget = self.plot.getCurvesRoiDockWidget().widget() + # hide last columns which are of no use now + for index in (5, 6, 7, 8): + self._curveRoiWidget.roiTable.setColumnHidden(index, True) + + # 2D - 3D roi manager + self._regionManager = RegionOfInterestManager(parent=self.plot) + + # Create the table widget displaying + self._2DRoiWidget = RegionOfInterestTableWidget() + self._2DRoiWidget.setRegionOfInterestManager(self._regionManager) + + # tabWidget for displaying the rois + self._roisTabWidget = qt.QTabWidget(parent=self) + if hasattr(self._roisTabWidget, 'setTabBarAutoHide'): + self._roisTabWidget.setTabBarAutoHide(True) + + # widget for displaying stats results and update mode + self._statsWidget = _RoiStatsWidget(parent=self, plot=self.plot) + + # create Dock widgets + self._roisTabWidgetDockWidget = qt.QDockWidget(parent=self) + self._roisTabWidgetDockWidget.setWidget(self._roisTabWidget) + self.addDockWidget(qt.Qt.RightDockWidgetArea, + self._roisTabWidgetDockWidget) + + # create Dock widgets + self._roiStatsWindowDockWidget = qt.QDockWidget(parent=self) + self._roiStatsWindowDockWidget.setWidget(self._statsWidget) + # move the docker contain in the parent widget + self.addDockWidget(qt.Qt.RightDockWidgetArea, + self._statsWidget._docker) + self.addDockWidget(qt.Qt.RightDockWidgetArea, + self._roiStatsWindowDockWidget) + + # expose API + self.setUpdateMode = self._statsWidget.setUpdateMode + + def setRois(self, rois1D=None, rois2D=None): + rois1D = rois1D or () + rois2D = rois2D or () + self._curveRoiWidget.setRois(rois1D) + for roi1D in rois1D: + self._statsWidget.registerROI(roi1D) + + for roi2D in rois2D: + self._regionManager.addRoi(roi2D) + self._statsWidget.registerROI(roi2D) + + # update manage tab visibility + if len(rois2D) > 0: + self._roisTabWidget.addTab(self._2DRoiWidget, '2D roi(s)') + if len(rois1D) > 0: + self._roisTabWidget.addTab(self._curveRoiWidget, '1D roi(s)') + + def setStats(self, stats): + self._statsWidget.setStats(stats=stats) + + def addItem(self, item, roi): + self._statsWidget.addItem(roi=roi, plotItem=item) + + +# define stats to display +STATS = [ + ('sum', numpy.sum), + ('mean', numpy.mean), +] + + +def get_1D_rois(): + """return some ROI instance""" + roi1D = ROI(name='range1', fromdata=0, todata=4, type_='energy') + roi2D = ROI(name='range2', fromdata=-2, todata=6, type_='energy') + return roi1D, roi2D + + +def get_2D_rois(): + """return some RectangleROI instance""" + rectangle_roi = RectangleROI() + rectangle_roi.setGeometry(origin=(0, 100), size=(20, 20)) + rectangle_roi.setName('Initial ROI') + polygon_roi = PolygonROI() + polygon_points = numpy.array([(0, 10), (10, 20), (45, 30), (35, 0)]) + polygon_roi.setPoints(polygon_points) + polygon_roi.setName('polygon ROI') + arc_roi = ArcROI() + arc_roi.setName('arc ROI') + arc_roi.setFirstShapePoints(numpy.array([[50, 10], [80, 120]])) + arc_roi.setGeometry(*arc_roi.getGeometry()) + return rectangle_roi, polygon_roi, arc_roi + + +def example_curve(mode): + """set up the roi stats example for curves""" + app = qt.QApplication([]) + roi_1, roi_2 = get_1D_rois() + window = _RoiStatsDisplayExWindow() + window.setRois(rois1D=(roi_2, roi_1)) + + # define some image and curve + window.plot.addCurve(x=numpy.linspace(0, 10, 56), y=numpy.arange(56), + legend='curve1', color='blue') + window.plot.addCurve(x=numpy.linspace(0, 10, 56), y=numpy.random.random_sample(size=56), + legend='curve2', color='red') + + window.setStats(STATS) + + # add some couple (plotItem, roi) to be displayed by default + curve1_item = window.plot.getCurve('curve1') + window.addItem(item=curve1_item, roi=roi_1) + window.addItem(item=curve1_item, roi=roi_2) + curve2_item = window.plot.getCurve('curve2') + window.addItem(item=curve2_item, roi=roi_2) + + window.setUpdateMode(mode) + + window.show() + app.exec_() + + +def example_image(mode): + """set up the roi stats example for images""" + app = qt.QApplication([]) + rectangle_roi, polygon_roi, arc_roi = get_2D_rois() + + window = _RoiStatsDisplayExWindow() + window.setRois(rois2D=(rectangle_roi, polygon_roi, arc_roi)) + # Create the thread that calls submitToQtMainThread + updateThread = UpdateThread(window.plot) + updateThread.start() # Start updating the plot + + # define some image and curve + window.plot.addImage(numpy.arange(10000).reshape(100, 100), legend='img1') + window.plot.addImage(numpy.random.random(10000).reshape(100, 100), legend='img2', + origin=(0, 100)) + window.setStats(STATS) + + # add some couple (plotItem, roi) to be displayed by default + img1_item = window.plot.getImage('img1') + img2_item = window.plot.getImage('img2') + window.addItem(item=img2_item, roi=rectangle_roi) + window.addItem(item=img1_item, roi=polygon_roi) + window.addItem(item=img1_item, roi=arc_roi) + + window.setUpdateMode(mode) + + window.show() + app.exec_() + updateThread.stop() # Stop updating the plot + + +def example_curve_image(mode): + """set up the roi stats example for curves and images""" + app = qt.QApplication([]) + + roi1D_1, roi1D_2 = get_1D_rois() + rectangle_roi, polygon_roi, arc_roi = get_2D_rois() + + window = _RoiStatsDisplayExWindow() + window.setRois(rois1D=(roi1D_1, roi1D_2,), + rois2D=(rectangle_roi, polygon_roi, arc_roi)) + + # define some image and curve + window.plot.addImage(numpy.arange(10000).reshape(100, 100), legend='img1') + window.plot.addImage(numpy.random.random(10000).reshape(100, 100), + legend='img2', origin=(0, 100)) + window.plot.addCurve(x=numpy.linspace(0, 10, 56), y=numpy.arange(56), + legend='curve1') + window.setStats(STATS) + + # add some couple (plotItem, roi) to be displayed by default + img_item = window.plot.getImage('img2') + window.addItem(item=img_item, roi=rectangle_roi) + curve_item = window.plot.getCurve('curve1') + window.addItem(item=curve_item, roi=roi1D_1) + + window.setUpdateMode(mode) + + # Create the thread that calls submitToQtMainThread + updateThread = UpdateThread(window.plot) + updateThread.start() # Start updating the plot + + window.show() + app.exec_() + updateThread.stop() # Stop updating the plot + + +def main(argv): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--items", dest="items", default='curves+images', + help="items type(s), can be curve, image, curves+images") + parser.add_argument('--mode', dest='mode', default='auto', + help='valid modes are `auto` or `manual`') + options = parser.parse_args(argv[1:]) + + items = options.items.lower() + if items == 'curves': + example_curve(mode=options.mode) + elif items == 'images': + example_image(mode=options.mode) + elif items == 'curves+images': + example_curve_image(mode=options.mode) + else: + raise ValueError('invalid entry for item type') + + +if __name__ == '__main__': + main(sys.argv) diff --git a/examples/plotStats.py b/examples/plotStats.py index 5f6e768..030caf8 100644 --- a/examples/plotStats.py +++ b/examples/plotStats.py @@ -33,13 +33,12 @@ On this example we will: - compute curve integrals (only for 'curve'). - compute center of mass for all possible items -.. note:: for now the possible types manged by the Stats are ('curve', 'image', - 'scatter' and 'histogram') +.. note:: stats are available for 1D and 2D at the time being """ __authors__ = ["H. Payno"] __license__ = "MIT" -__date__ = "24/07/2018" +__date__ = "23/07/2019" from silx.gui import qt diff --git a/package/debian10/control b/package/debian10/control index e98ce65..d724e69 100644 --- a/package/debian10/control +++ b/package/debian10/control @@ -11,7 +11,7 @@ Build-Depends: cython3 (>= 0.23.2), graphviz, help2man, ipython3, - ipython3-qtconsole, + python3-qtconsole, pandoc , python3-all-dev, python3-dateutil, diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c80dee7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[build-system] +requires = [ + "wheel", + "setuptools", + "numpy>=1.12", + "Cython>=0.21.1" +] diff --git a/run_tests.py b/run_tests.py index 6007344..5d3155a 100755 --- a/run_tests.py +++ b/run_tests.py @@ -1,8 +1,8 @@ -#!/usr/bin/env python -# coding: utf-8 +#!/usr/bin/env python3 +# coding: utf8 # /*########################################################################## # -# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# Copyright (c) 2015-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -32,7 +32,7 @@ Test coverage dependencies: coverage, lxml. """ __authors__ = ["Jérôme Kieffer", "Thomas Vincent"] -__date__ = "02/03/2018" +__date__ = "30/09/2020" __license__ = "MIT" import distutils.util @@ -87,7 +87,6 @@ logger.setLevel(logging.WARNING) logger.info("Python %s %s", sys.version, tuple.__itemsize__ * 8) - try: import resource except ImportError: @@ -98,6 +97,7 @@ try: import importlib importer = importlib.import_module except ImportError: + def importer(name): module = __import__(name) # returns the leaf module, instead of the root module @@ -107,7 +107,6 @@ except ImportError: module = getattr(module, subname) return module - try: import numpy except Exception as error: @@ -350,11 +349,9 @@ if __name__ == "__main__": # Needed for multiprocessing support on Windows PROJECT_VERSION = getattr(project_module, 'version', '') PROJECT_PATH = project_module.__path__[0] - test_options = get_test_options(project_module) """Contains extra configuration for the tests.""" - epilog = """Environment variables: WITH_QT_TEST=False to disable graphical tests SILX_OPENCL=False to disable OpenCL tests @@ -393,7 +390,6 @@ if __name__ == "__main__": # Needed for multiprocessing support on Windows options = parser.parse_args() sys.argv = [sys.argv[0]] - test_verbosity = 1 use_buffer = True if options.verbose == 1: @@ -467,7 +463,6 @@ if __name__ == "__main__": # Needed for multiprocessing support on Windows else: logger.warning("No test options available.") - if not options.test_name: # Do not use test loader to avoid cryptic exception # when an error occur during import @@ -487,7 +482,6 @@ if __name__ == "__main__": # Needed for multiprocessing support on Windows else: exit_status = 1 - if options.coverage: cov.stop() cov.save() diff --git a/silx.egg-info/PKG-INFO b/silx.egg-info/PKG-INFO index 74f97d5..6bf8a6f 100644 --- a/silx.egg-info/PKG-INFO +++ b/silx.egg-info/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: silx -Version: 0.13.1 +Version: 0.14.0 Summary: Software library for X-ray data analysis Home-page: http://www.silx.org/ Author: data analysis unit @@ -108,8 +108,8 @@ Description: *silx* releases can be cited via their DOI on Zenodo: |zenodo DOI| - .. |Travis Status| image:: https://travis-ci.org/silx-kit/silx.svg?branch=master - :target: https://travis-ci.org/silx-kit/silx?branch=master + .. |Travis Status| image:: https://travis-ci.com/silx-kit/silx.svg?branch=master + :target: https://travis-ci.com/silx-kit/silx .. |Appveyor Status| image:: https://ci.appveyor.com/api/projects/status/qgox9ei0wxwfagrb/branch/master?svg=true :target: https://ci.appveyor.com/project/ESRF/silx?branch=master .. |zenodo DOI| image:: https://zenodo.org/badge/DOI/10.5281/zenodo.591709.svg diff --git a/silx.egg-info/SOURCES.txt b/silx.egg-info/SOURCES.txt index 3e27a4d..887aaf1 100644 --- a/silx.egg-info/SOURCES.txt +++ b/silx.egg-info/SOURCES.txt @@ -4,6 +4,7 @@ MANIFEST.in README.rst build-deb.sh copyright +pyproject.toml requirements-dev.txt requirements.txt run_tests.py @@ -134,6 +135,7 @@ doc/source/modules/gui/plot/plotwindow.rst doc/source/modules/gui/plot/printpreviewtoolbutton.rst doc/source/modules/gui/plot/profile.rst doc/source/modules/gui/plot/roi.rst +doc/source/modules/gui/plot/roistatswidget.rst doc/source/modules/gui/plot/scatterview.rst doc/source/modules/gui/plot/stackview.rst doc/source/modules/gui/plot/statswidget.rst @@ -160,6 +162,7 @@ doc/source/modules/gui/plot/img/Plot2D.png doc/source/modules/gui/plot/img/PlotWidget.png doc/source/modules/gui/plot/img/PlotWindow.png doc/source/modules/gui/plot/img/PositionInfo.png +doc/source/modules/gui/plot/img/ROIStatsWidget.png doc/source/modules/gui/plot/img/ScatterView.png doc/source/modules/gui/plot/img/StackView.png doc/source/modules/gui/plot/img/StackViewMainWindow.png @@ -303,6 +306,7 @@ doc/source/sample_code/img/plotInteractiveImageROI.png doc/source/sample_code/img/plotItemsSelector.png doc/source/sample_code/img/plotLimits.png doc/source/sample_code/img/plotProfile.png +doc/source/sample_code/img/plotROIStats.png doc/source/sample_code/img/plotStats.png doc/source/sample_code/img/plotUpdateCurveFromThread.png doc/source/sample_code/img/plotUpdateImageFromThread.png @@ -344,6 +348,7 @@ examples/plotInteractiveImageROI.py examples/plotItemsSelector.py examples/plotLimits.py examples/plotProfile.py +examples/plotROIStats.py examples/plotStats.py examples/plotUpdateCurveFromThread.py examples/plotUpdateImageFromThread.py @@ -541,6 +546,7 @@ silx/gui/plot/PlotWindow.py silx/gui/plot/PrintPreviewToolButton.py silx/gui/plot/Profile.py silx/gui/plot/ProfileMainWindow.py +silx/gui/plot/ROIStatsWidget.py silx/gui/plot/ScatterMaskToolsWidget.py silx/gui/plot/ScatterView.py silx/gui/plot/StackView.py @@ -573,6 +579,7 @@ silx/gui/plot/backends/__init__.py silx/gui/plot/backends/glutils/GLPlotCurve.py silx/gui/plot/backends/glutils/GLPlotFrame.py silx/gui/plot/backends/glutils/GLPlotImage.py +silx/gui/plot/backends/glutils/GLPlotItem.py silx/gui/plot/backends/glutils/GLPlotTriangles.py silx/gui/plot/backends/glutils/GLSupport.py silx/gui/plot/backends/glutils/GLText.py @@ -580,7 +587,9 @@ silx/gui/plot/backends/glutils/GLTexture.py silx/gui/plot/backends/glutils/PlotImageFile.py silx/gui/plot/backends/glutils/__init__.py silx/gui/plot/items/__init__.py +silx/gui/plot/items/_arc_roi.py silx/gui/plot/items/_pick.py +silx/gui/plot/items/_roi_base.py silx/gui/plot/items/axis.py silx/gui/plot/items/complex.py silx/gui/plot/items/core.py @@ -614,6 +623,7 @@ silx/gui/plot/test/testPlotInteraction.py silx/gui/plot/test/testPlotWidget.py silx/gui/plot/test/testPlotWidgetNoBackend.py silx/gui/plot/test/testPlotWindow.py +silx/gui/plot/test/testRoiStatsWidget.py silx/gui/plot/test/testSaveAction.py silx/gui/plot/test/testScatterMaskToolsWidget.py silx/gui/plot/test/testScatterView.py @@ -720,8 +730,10 @@ silx/gui/utils/__init__.py silx/gui/utils/concurrent.py silx/gui/utils/glutils.py silx/gui/utils/image.py +silx/gui/utils/matplotlib.py silx/gui/utils/projecturl.py silx/gui/utils/qtutils.py +silx/gui/utils/signal.py silx/gui/utils/testutils.py silx/gui/utils/test/__init__.py silx/gui/utils/test/test.py @@ -756,6 +768,7 @@ silx/gui/widgets/test/test_elidedlabel.py silx/gui/widgets/test/test_flowlayout.py silx/gui/widgets/test/test_framebrowser.py silx/gui/widgets/test/test_hierarchicaltableview.py +silx/gui/widgets/test/test_legendiconwidget.py silx/gui/widgets/test/test_periodictable.py silx/gui/widgets/test/test_printpreview.py silx/gui/widgets/test/test_rangeslider.py @@ -1010,10 +1023,14 @@ silx/resources/gui/icons/add-shape-unknown.png silx/resources/gui/icons/add-shape-unknown.svg silx/resources/gui/icons/add-shape-vertical.png silx/resources/gui/icons/add-shape-vertical.svg +silx/resources/gui/icons/add.png +silx/resources/gui/icons/add.svg silx/resources/gui/icons/arrow-keys.png silx/resources/gui/icons/arrow-keys.svg silx/resources/gui/icons/axis.png silx/resources/gui/icons/axis.svg +silx/resources/gui/icons/backend-opengl.png +silx/resources/gui/icons/backend-opengl.svg silx/resources/gui/icons/camera.png silx/resources/gui/icons/camera.svg silx/resources/gui/icons/clipboard.png @@ -1261,6 +1278,8 @@ silx/resources/gui/icons/profile2D.png silx/resources/gui/icons/profile2D.svg silx/resources/gui/icons/remove.png silx/resources/gui/icons/remove.svg +silx/resources/gui/icons/rm.png +silx/resources/gui/icons/rm.svg silx/resources/gui/icons/rotate-3d.png silx/resources/gui/icons/rotate-3d.svg silx/resources/gui/icons/rudder.png @@ -1474,7 +1493,7 @@ silx/third_party/_local/scipy_spatial/qhull/src/userprintf_r.c silx/third_party/_local/scipy_spatial/qhull/src/userprintf_rbox_r.c silx/utils/ExternalResources.py silx/utils/__init__.py -silx/utils/_have_openmp.pxi +silx/utils/_have_openmp.pxd silx/utils/array_like.py silx/utils/debug.py silx/utils/deprecation.py diff --git a/silx/app/test/test_convert.py b/silx/app/test/test_convert.py index bb1ae99..857f30c 100644 --- a/silx/app/test/test_convert.py +++ b/silx/app/test/test_convert.py @@ -40,7 +40,7 @@ import h5py import silx from .. import convert from silx.utils import testutils - +from silx.io.utils import h5py_read_dataset # content of a spec file @@ -137,7 +137,7 @@ class TestConvertCommand(unittest.TestCase): self.assertTrue(os.path.isfile(h5name)) with h5py.File(h5name, "r") as h5f: - title12 = h5f["/1.2/title"][()] + title12 = h5py_read_dataset(h5f["/1.2/title"]) if sys.version_info < (3, ): title12 = title12.encode("utf-8") self.assertEqual(title12, diff --git a/silx/app/view/Viewer.py b/silx/app/view/Viewer.py index 9503533..dd4d075 100644 --- a/silx/app/view/Viewer.py +++ b/silx/app/view/Viewer.py @@ -116,6 +116,8 @@ class Viewer(qt.QMainWindow): spliter.addWidget(rightPanel) spliter.addWidget(self.__dataPanel) spliter.setStretchFactor(1, 1) + spliter.setCollapsible(0, False) + spliter.setCollapsible(1, False) self.__splitter = spliter main_panel = qt.QWidget(self) diff --git a/silx/app/view/main.py b/silx/app/view/main.py index c7afc19..a1369c1 100644 --- a/silx/app/view/main.py +++ b/silx/app/view/main.py @@ -1,6 +1,6 @@ # coding: utf-8 # /*########################################################################## -# Copyright (C) 2016-2019 European Synchrotron Radiation Facility +# Copyright (C) 2016-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -57,7 +57,7 @@ def createParser(): default=False, help='Use OpenGL for plots (instead of matplotlib)') parser.add_argument( - '--fresh', + '-f', '--fresh', dest="fresh_preferences", action="store_true", default=False, @@ -104,7 +104,7 @@ def mainQt(options): from silx.gui import qt # Make sure matplotlib is configured # Needed for Debian 8: compatibility between Qt4/Qt5 and old matplotlib - from silx.gui.plot import matplotlib + import silx.gui.utils.matplotlib # noqa app = qt.QApplication([]) qt.QLocale.setDefault(qt.QLocale.c()) diff --git a/silx/gui/_glutils/FramebufferTexture.py b/silx/gui/_glutils/FramebufferTexture.py index cc05080..e065030 100644 --- a/silx/gui/_glutils/FramebufferTexture.py +++ b/silx/gui/_glutils/FramebufferTexture.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2017 European Synchrotron Radiation Facility +# Copyright (c) 2014-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -62,6 +62,7 @@ class FramebufferTexture(object): **kwargs): self._texture = Texture(internalFormat, shape=shape, **kwargs) + self._texture.prepare() self._previousFramebuffer = 0 # Used by with statement diff --git a/silx/gui/_glutils/OpenGLWidget.py b/silx/gui/_glutils/OpenGLWidget.py index 1f7bfae..5e3fcb8 100644 --- a/silx/gui/_glutils/OpenGLWidget.py +++ b/silx/gui/_glutils/OpenGLWidget.py @@ -329,6 +329,20 @@ class OpenGLWidget(qt.QWidget): else: return self.__openGLWidget.getDevicePixelRatio() + def getDotsPerInch(self): + """Returns current screen resolution as device pixels per inch. + + :rtype: float + """ + screen = self.window().windowHandle().screen() + if screen is not None: + # TODO check if this is correct on different OS/screen + # OK on macOS10.12/qt5.13.2 + dpi = screen.physicalDotsPerInch() * self.getDevicePixelRatio() + else: # Fallback + dpi = 96. * self.getDevicePixelRatio() + return dpi + def getOpenGLVersion(self): """Returns the available OpenGL version. diff --git a/silx/gui/_glutils/Texture.py b/silx/gui/_glutils/Texture.py index a7fd44b..c72135a 100644 --- a/silx/gui/_glutils/Texture.py +++ b/silx/gui/_glutils/Texture.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2019 European Synchrotron Radiation Facility +# Copyright (c) 2014-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -81,20 +81,23 @@ class Texture(object): else: shape = data.shape + self._deferredUpdates = [(format_, data, None)] + assert len(shape) in (2, 3) self._shape = tuple(shape) self._ndim = len(shape) self.texUnit = texUnit - self._name = gl.glGenTextures(1) - self.bind(self.texUnit) + self._texParameterUpdates = {} # Store texture params to update + + self._minFilter = minFilter if minFilter is not None else gl.GL_NEAREST + self._texParameterUpdates[gl.GL_TEXTURE_MIN_FILTER] = self._minFilter - self._minFilter = None - self.minFilter = minFilter if minFilter is not None else gl.GL_NEAREST + self._magFilter = magFilter if magFilter is not None else gl.GL_LINEAR + self._texParameterUpdates[gl.GL_TEXTURE_MAG_FILTER] = self._magFilter - self._magFilter = None - self.magFilter = magFilter if magFilter is not None else gl.GL_LINEAR + self._name = None # Store texture ID if wrap is not None: if not isinstance(wrap, abc.Iterable): @@ -102,69 +105,10 @@ class Texture(object): assert len(wrap) == self.ndim - gl.glTexParameter(self.target, - gl.GL_TEXTURE_WRAP_S, - wrap[-1]) - gl.glTexParameter(self.target, - gl.GL_TEXTURE_WRAP_T, - wrap[-2]) + self._texParameterUpdates[gl.GL_TEXTURE_WRAP_S] = wrap[-1] + self._texParameterUpdates[gl.GL_TEXTURE_WRAP_T] = wrap[-2] if self.ndim == 3: - gl.glTexParameter(self.target, - gl.GL_TEXTURE_WRAP_R, - wrap[0]) - - gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) - - # This are the defaults, useless to set if not modified - # gl.glPixelStorei(gl.GL_UNPACK_ROW_LENGTH, 0) - # gl.glPixelStorei(gl.GL_UNPACK_SKIP_PIXELS, 0) - # gl.glPixelStorei(gl.GL_UNPACK_SKIP_ROWS, 0) - # gl.glPixelStorei(gl.GL_UNPACK_IMAGE_HEIGHT, 0) - # gl.glPixelStorei(gl.GL_UNPACK_SKIP_IMAGES, 0) - - if data is None: - data = c_void_p(0) - type_ = gl.GL_UNSIGNED_BYTE - else: - type_ = utils.numpyToGLType(data.dtype) - - if self.ndim == 2: - _logger.debug( - 'Creating 2D texture shape: (%d, %d),' - ' internal format: %s, format: %s, type: %s', - self.shape[0], self.shape[1], - str(self.internalFormat), str(format_), str(type_)) - - gl.glTexImage2D( - gl.GL_TEXTURE_2D, - 0, - self.internalFormat, - self.shape[1], - self.shape[0], - 0, - format_, - type_, - data) - else: - _logger.debug( - 'Creating 3D texture shape: (%d, %d, %d),' - ' internal format: %s, format: %s, type: %s', - self.shape[0], self.shape[1], self.shape[2], - str(self.internalFormat), str(format_), str(type_)) - - gl.glTexImage3D( - gl.GL_TEXTURE_3D, - 0, - self.internalFormat, - self.shape[2], - self.shape[1], - self.shape[0], - 0, - format_, - type_, - data) - - gl.glBindTexture(self.target, 0) + self._texParameterUpdates[gl.GL_TEXTURE_WRAP_R] = wrap[0] @property def target(self): @@ -188,12 +132,11 @@ class Texture(object): @property def name(self): - """OpenGL texture name""" - if self._name is not None: - return self._name - else: - raise RuntimeError( - "No OpenGL texture resource, discard has already been called") + """OpenGL texture name. + + It is None if not initialized or already discarded. + """ + return self._name @property def minFilter(self): @@ -204,10 +147,7 @@ class Texture(object): def minFilter(self, minFilter): if minFilter != self.minFilter: self._minFilter = minFilter - self.bind() - gl.glTexParameter(self.target, - gl.GL_TEXTURE_MIN_FILTER, - self.minFilter) + self._texParameterUpdates[gl.GL_TEXTURE_MIN_FILTER] = minFilter @property def magFilter(self): @@ -218,20 +158,112 @@ class Texture(object): def magFilter(self, magFilter): if magFilter != self.magFilter: self._magFilter = magFilter - self.bind() - gl.glTexParameter(self.target, - gl.GL_TEXTURE_MAG_FILTER, - self.magFilter) + self._texParameterUpdates[gl.GL_TEXTURE_MAG_FILTER] = magFilter - def discard(self): - """Delete associated OpenGL texture""" - if self._name is not None: - gl.glDeleteTextures(self._name) - self._name = None - else: - _logger.warning("Discard as already been called") + def _isPrepareRequired(self) -> bool: + """Returns True if OpenGL texture needs to be updated. - def bind(self, texUnit=None): + :rtype: bool + """ + return (self._name is None or + self._texParameterUpdates or + self._deferredUpdates) + + def _prepareAndBind(self, texUnit=None): + """Synchronizes the OpenGL texture""" + if self._name is None: + self._name = gl.glGenTextures(1) + + self._bind(texUnit) + + # Synchronizes texture parameters + for pname, param in self._texParameterUpdates.items(): + gl.glTexParameter(self.target, pname, param) + self._texParameterUpdates = {} + + # Copy data to texture + for format_, data, offset in self._deferredUpdates: + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + + # This are the defaults, useless to set if not modified + # gl.glPixelStorei(gl.GL_UNPACK_ROW_LENGTH, 0) + # gl.glPixelStorei(gl.GL_UNPACK_SKIP_PIXELS, 0) + # gl.glPixelStorei(gl.GL_UNPACK_SKIP_ROWS, 0) + # gl.glPixelStorei(gl.GL_UNPACK_IMAGE_HEIGHT, 0) + # gl.glPixelStorei(gl.GL_UNPACK_SKIP_IMAGES, 0) + + if data is None: + data = c_void_p(0) + type_ = gl.GL_UNSIGNED_BYTE + else: + type_ = utils.numpyToGLType(data.dtype) + + if offset is None: # Initialize texture + if self.ndim == 2: + _logger.debug( + 'Creating 2D texture shape: (%d, %d),' + ' internal format: %s, format: %s, type: %s', + self.shape[0], self.shape[1], + str(self.internalFormat), str(format_), str(type_)) + + gl.glTexImage2D( + gl.GL_TEXTURE_2D, + 0, + self.internalFormat, + self.shape[1], + self.shape[0], + 0, + format_, + type_, + data) + + else: + _logger.debug( + 'Creating 3D texture shape: (%d, %d, %d),' + ' internal format: %s, format: %s, type: %s', + self.shape[0], self.shape[1], self.shape[2], + str(self.internalFormat), str(format_), str(type_)) + + gl.glTexImage3D( + gl.GL_TEXTURE_3D, + 0, + self.internalFormat, + self.shape[2], + self.shape[1], + self.shape[0], + 0, + format_, + type_, + data) + + else: # Update already existing texture + if self.ndim == 2: + gl.glTexSubImage2D(gl.GL_TEXTURE_2D, + 0, + offset[1], + offset[0], + data.shape[1], + data.shape[0], + format_, + type_, + data) + + else: + gl.glTexSubImage3D(gl.GL_TEXTURE_3D, + 0, + offset[2], + offset[1], + offset[0], + data.shape[2], + data.shape[1], + data.shape[0], + format_, + type_, + data) + + self._deferredUpdates = [] + + def _bind(self, texUnit=None): """Bind the texture to a texture unit. :param int texUnit: The texture unit to use @@ -241,73 +273,80 @@ class Texture(object): gl.glActiveTexture(gl.GL_TEXTURE0 + texUnit) gl.glBindTexture(self.target, self.name) + def _unbind(self, texUnit=None): + """Reset texture binding to a texture unit. + + :param int texUnit: The texture unit to use + """ + if texUnit is None: + texUnit = self.texUnit + gl.glActiveTexture(gl.GL_TEXTURE0 + texUnit) + gl.glBindTexture(self.target, 0) + + def prepare(self): + """Synchronizes the OpenGL texture. + + This method must be called with a current OpenGL context. + """ + if self._isPrepareRequired(): + self._prepareAndBind() + self._unbind() + + def bind(self, texUnit=None): + """Bind the texture to a texture unit. + + The OpenGL texture is updated if needed. + + This method must be called with a current OpenGL context. + + :param int texUnit: The texture unit to use + """ + if self._isPrepareRequired(): + self._prepareAndBind(texUnit) + else: + self._bind(texUnit) + + def discard(self): + """Delete associated OpenGL texture. + + This method must be called with a current OpenGL context. + """ + if self._name is not None: + gl.glDeleteTextures(self._name) + self._name = None + else: + _logger.warning("Texture not initialized or already discarded") + # with statement def __enter__(self): self.bind() def __exit__(self, exc_type, exc_val, exc_tb): - gl.glActiveTexture(gl.GL_TEXTURE0 + self.texUnit) - gl.glBindTexture(self.target, 0) + self._unbind() - def update(self, - format_, - data, - offset=(0, 0, 0), - texUnit=None): + def update(self, format_, data, offset=(0, 0, 0), copy=True): """Update the content of the texture. Texture is not resized, so data must fit into texture with the given offset. + This update is performed lazily during next call to + :meth:`prepare` or :meth:`bind`. + Data MUST not be changed until then. + :param format_: The OpenGL format of the data :param data: The data to use to update the texture - :param offset: The offset in the texture where to copy the data - :type offset: List[int] - :param int texUnit: - The texture unit to use (default: the one provided at init) + :param List[int] offset: Offset in the texture where to copy the data + :param bool copy: + True (default) to copy data, False to use as is (do not modify) """ - data = numpy.array(data, copy=False, order='C') + data = numpy.array(data, copy=copy, order='C') + offset = tuple(offset) assert data.ndim == self.ndim assert len(offset) >= self.ndim for i in range(self.ndim): assert offset[i] + data.shape[i] <= self.shape[i] - gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) - - # This are the defaults, useless to set if not modified - # gl.glPixelStorei(gl.GL_UNPACK_ROW_LENGTH, 0) - # gl.glPixelStorei(gl.GL_UNPACK_SKIP_PIXELS, 0) - # gl.glPixelStorei(gl.GL_UNPACK_SKIP_ROWS, 0) - # gl.glPixelStorei(gl.GL_UNPACK_IMAGE_HEIGHT, 0) - # gl.glPixelStorei(gl.GL_UNPACK_SKIP_IMAGES, 0) - - self.bind(texUnit) - - type_ = utils.numpyToGLType(data.dtype) - - if self.ndim == 2: - gl.glTexSubImage2D(gl.GL_TEXTURE_2D, - 0, - offset[1], - offset[0], - data.shape[1], - data.shape[0], - format_, - type_, - data) - gl.glBindTexture(gl.GL_TEXTURE_2D, 0) - else: - gl.glTexSubImage3D(gl.GL_TEXTURE_3D, - 0, - offset[2], - offset[1], - offset[0], - data.shape[2], - data.shape[1], - data.shape[0], - format_, - type_, - data) - gl.glBindTexture(gl.GL_TEXTURE_3D, 0) + self._deferredUpdates.append((format_, data, offset)) diff --git a/silx/gui/_glutils/utils.py b/silx/gui/_glutils/utils.py index 35cf819..d5627ef 100644 --- a/silx/gui/_glutils/utils.py +++ b/silx/gui/_glutils/utils.py @@ -29,45 +29,25 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" __date__ = "10/01/2017" -from . import gl import numpy - -_GL_TYPE_SIZES = { - gl.GL_FLOAT: 4, - gl.GL_BYTE: 1, - gl.GL_SHORT: 2, - gl.GL_INT: 4, - gl.GL_UNSIGNED_BYTE: 1, - gl.GL_UNSIGNED_SHORT: 2, - gl.GL_UNSIGNED_INT: 4, -} +from OpenGL.constants import BYTE_SIZES as _BYTE_SIZES +from OpenGL.constants import ARRAY_TO_GL_TYPE_MAPPING as _ARRAY_TO_GL_TYPE_MAPPING def sizeofGLType(type_): """Returns the size in bytes of an element of type `type_`""" - return _GL_TYPE_SIZES[type_] - - -_TYPE_CONVERTER = { - numpy.dtype(numpy.float32): gl.GL_FLOAT, - numpy.dtype(numpy.int8): gl.GL_BYTE, - numpy.dtype(numpy.int16): gl.GL_SHORT, - numpy.dtype(numpy.int32): gl.GL_INT, - numpy.dtype(numpy.uint8): gl.GL_UNSIGNED_BYTE, - numpy.dtype(numpy.uint16): gl.GL_UNSIGNED_SHORT, - numpy.dtype(numpy.uint32): gl.GL_UNSIGNED_INT, -} + return _BYTE_SIZES[type_] def isSupportedGLType(type_): """Test if a numpy type or dtype can be converted to a GL type.""" - return numpy.dtype(type_) in _TYPE_CONVERTER + return numpy.dtype(type_).char in _ARRAY_TO_GL_TYPE_MAPPING def numpyToGLType(type_): """Returns the GL type corresponding the provided numpy type or dtype.""" - return _TYPE_CONVERTER[numpy.dtype(type_)] + return _ARRAY_TO_GL_TYPE_MAPPING[numpy.dtype(type_).char] def segmentTrianglesIntersection(segment, triangles): diff --git a/silx/gui/colors.py b/silx/gui/colors.py index 4d750ba..4a96ae0 100755 --- a/silx/gui/colors.py +++ b/silx/gui/colors.py @@ -34,7 +34,10 @@ __date__ = "29/01/2019" import numpy import logging import collections +import warnings + from silx.gui import qt +from silx.gui.utils import blockSignals from silx.math.combo import min_max from silx.math import colormap as _colormap from silx.utils.exceptions import NotEditableError @@ -45,10 +48,13 @@ from silx.resources import resource_filename as _resource_filename _logger = logging.getLogger(__file__) try: + import silx.gui.utils.matplotlib # noqa Initalize matplotlib from matplotlib import cm as _matplotlib_cm + from matplotlib.pyplot import colormaps as _matplotlib_colormaps except ImportError: _logger.info("matplotlib not available, only embedded colormaps available") _matplotlib_cm = None + _matplotlib_colormaps = None _COLORDICT = {} @@ -362,7 +368,22 @@ class _NormalizationMixIn: if mode == Colormap.MINMAX: vmin, vmax = self.autoscaleMinMax(data) elif mode == Colormap.STDDEV3: - vmin, vmax = self.autoscaleMean3Std(data) + dmin, dmax = self.autoscaleMinMax(data) + stdmin, stdmax = self.autoscaleMean3Std(data) + if dmin is None: + vmin = stdmin + elif stdmin is None: + vmin = dmin + else: + vmin = max(dmin, stdmin) + + if dmax is None: + vmax = stdmax + elif stdmax is None: + vmax = dmax + else: + vmax = min(dmax, stdmax) + else: raise ValueError('Unsupported mode: %s' % mode) @@ -405,7 +426,13 @@ class _NormalizationMixIn: normdata[numpy.isfinite(normdata) == False] = numpy.nan if normdata.size == 0: # Fallback return None, None - mean, std = numpy.nanmean(normdata), numpy.nanstd(normdata) + + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore nanmean "Mean of empty slice" warning and + # nanstd "Degrees of freedom <= 0 for slice" warning + mean, std = numpy.nanmean(normdata), numpy.nanstd(normdata) + return self.revert(mean - 3 * std, 0., 1.), self.revert(mean + 3 * std, 0., 1.) @@ -426,7 +453,11 @@ class _LinearNormalizationMixIn(_NormalizationMixIn): data[numpy.isfinite(data) == False] = numpy.nan if data.size == 0: # Fallback return None, None - mean, std = numpy.nanmean(data), numpy.nanstd(data) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + # Ignore nanmean "Mean of empty slice" warning and + # nanstd "Degrees of freedom <= 0 for slice" warning + mean, std = numpy.nanmean(data), numpy.nanstd(data) return mean - 3 * std, mean + 3 * std @@ -534,7 +565,8 @@ class Colormap(qt.QObject): """constant for autoscale using min/max data range""" STDDEV3 = 'stddev3' - """constant for autoscale using mean +/- 3*std(data)""" + """constant for autoscale using mean +/- 3*std(data) + with a clamp on min/max of the data""" AUTOSCALE_MODES = (MINMAX, STDDEV3) """Tuple of managed auto scale algorithms""" @@ -542,10 +574,14 @@ class Colormap(qt.QObject): sigChanged = qt.Signal() """Signal emitted when the colormap has changed.""" + _DEFAULT_NAN_COLOR = 255, 255, 255, 0 + def __init__(self, name=None, colors=None, normalization=LINEAR, vmin=None, vmax=None, autoscaleMode=MINMAX): qt.QObject.__init__(self) self._editable = True self.__gamma = 2.0 + # Default NaN color: fully transparent white + self.__nanColor = numpy.array(self._DEFAULT_NAN_COLOR, dtype=numpy.uint8) assert normalization in Colormap.NORMALIZATIONS assert autoscaleMode in Colormap.AUTOSCALE_MODES @@ -593,15 +629,19 @@ class Colormap(qt.QObject): raise NotEditableError('Colormap is not editable') if self == other: return - old = self.blockSignals(True) - name = other.getName() - if name is not None: - self.setName(name) - else: - self.setColormapLUT(other.getColormapLUT()) - self.setNormalization(other.getNormalization()) - self.setVRange(other.getVMin(), other.getVMax()) - self.blockSignals(old) + with blockSignals(self): + name = other.getName() + if name is not None: + self.setName(name) + else: + self.setColormapLUT(other.getColormapLUT()) + self.setNaNColor(other.getNaNColor()) + self.setNormalization(other.getNormalization()) + self.setGammaNormalizationParameter( + other.getGammaNormalizationParameter()) + self.setAutoscaleMode(other.getAutoscaleMode()) + self.setVRange(*other.getVRange()) + self.setEditable(other.isEditable()) self.sigChanged.emit() def getNColors(self, nbColors=None): @@ -623,7 +663,7 @@ class Colormap(qt.QObject): colormap.setNormalization(Colormap.LINEAR) colormap.setVRange(vmin=0, vmax=nbColors - 1) colors = colormap.applyToData( - numpy.arange(nbColors, dtype=numpy.int)) + numpy.arange(nbColors, dtype=numpy.int32)) return colors def getName(self): @@ -689,6 +729,24 @@ class Colormap(qt.QObject): self._name = None self.sigChanged.emit() + def getNaNColor(self): + """Returns the color to use for Not-A-Number floating point value. + + :rtype: QColor + """ + return qt.QColor(*self.__nanColor) + + def setNaNColor(self, color): + """Set the color to use for Not-A-Number floating point value. + + :param color: RGB(A) color to use for NaN values + :type color: QColor, str, tuple of uint8 or float in [0., 1.] + """ + color = (numpy.array(rgba(color)) * 255).astype(numpy.uint8) + if not numpy.array_equal(self.__nanColor, color): + self.__nanColor = color + self.sigChanged.emit() + def getNormalization(self): """Return the normalization of the colormap. @@ -1021,8 +1079,10 @@ class Colormap(qt.QObject): vmax=self._vmax, normalization=self.getNormalization(), autoscaleMode=self.getAutoscaleMode()) + colormap.setNaNColor(self.getNaNColor()) colormap.setGammaNormalizationParameter( self.getGammaNormalizationParameter()) + colormap.setEditable(self.isEditable()) return colormap def applyToData(self, data, reference=None): @@ -1038,10 +1098,15 @@ class Colormap(qt.QObject): vmin, vmax = self.getColormapRange(reference) if hasattr(data, "getColormappedData"): # Use item's data - data = data.getColormappedData() + data = data.getColormappedData(copy=False) return _colormap.cmap( - data, self._colors, vmin, vmax, self._getNormalizer()) + data, + self._colors, + vmin, + vmax, + self._getNormalizer(), + self.__nanColor) @staticmethod def getSupportedColormaps(): @@ -1055,8 +1120,8 @@ class Colormap(qt.QObject): :rtype: tuple """ colormaps = set() - if _matplotlib_cm is not None: - colormaps.update(_matplotlib_cm.cmap_d.keys()) + if _matplotlib_colormaps is not None: + colormaps.update(_matplotlib_colormaps()) colormaps.update(_AVAILABLE_LUTS.keys()) colormaps = tuple(cmap for cmap in sorted(colormaps) @@ -1086,7 +1151,7 @@ class Colormap(qt.QObject): numpy.array_equal(self.getColormapLUT(), other.getColormapLUT()) ) - _SERIAL_VERSION = 2 + _SERIAL_VERSION = 3 def restoreState(self, byteArray): """ @@ -1106,7 +1171,7 @@ class Colormap(qt.QObject): return False version = stream.readUInt32() - if version not in (1, self._SERIAL_VERSION): + if version not in numpy.arange(1, self._SERIAL_VERSION+1): _logger.warning("Serial version mismatch. Found %d." % version) return False @@ -1133,6 +1198,11 @@ class Colormap(qt.QObject): else: autoscaleMode = stream.readQString() + if version <= 2: + nanColor = self._DEFAULT_NAN_COLOR + else: + nanColor = stream.readInt32(), stream.readInt32(), stream.readInt32(), stream.readInt32() + # emit change event only once old = self.blockSignals(True) try: @@ -1142,6 +1212,7 @@ class Colormap(qt.QObject): self.setVRange(vmin, vmax) if gamma is not None: self.setGammaNormalizationParameter(gamma) + self.setNaNColor(nanColor) finally: self.blockSignals(old) self.sigChanged.emit() @@ -1169,6 +1240,12 @@ class Colormap(qt.QObject): if self.getNormalization() == Colormap.GAMMA: stream.writeFloat(self.getGammaNormalizationParameter()) stream.writeQString(self.getAutoscaleMode()) + nanColor = self.getNaNColor() + stream.writeInt32(nanColor.red()) + stream.writeInt32(nanColor.green()) + stream.writeInt32(nanColor.blue()) + stream.writeInt32(nanColor.alpha()) + return data diff --git a/silx/gui/data/DataViews.py b/silx/gui/data/DataViews.py index f3b02b9..d9958de 100644 --- a/silx/gui/data/DataViews.py +++ b/silx/gui/data/DataViews.py @@ -406,7 +406,7 @@ class DataView(object): :param NamedTuple selection: Data selected :rtype: str """ - if selection is None: + if selection is None or selection.filename is None: return None else: directory, filename = os.path.split(selection.filename) diff --git a/silx/gui/data/Hdf5TableView.py b/silx/gui/data/Hdf5TableView.py index 57d6f7b..7749326 100644 --- a/silx/gui/data/Hdf5TableView.py +++ b/silx/gui/data/Hdf5TableView.py @@ -380,37 +380,87 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel): SEPARATOR = "::" self.__data.addHeaderRow(headerLabel="Path info") + showPhysicalLocation = True if isinstance(obj, silx.gui.hdf5.H5Node): # helpful informations if the object come from an HDF5 tree self.__data.addHeaderValueRow("Basename", lambda x: x.local_basename) self.__data.addHeaderValueRow("Name", lambda x: x.local_name) local = lambda x: x.local_filename + SEPARATOR + x.local_name self.__data.addHeaderValueRow("Local", local) - physical = lambda x: x.physical_filename + SEPARATOR + x.physical_name - self.__data.addHeaderValueRow("Physical", physical) else: # it's a real H5py object self.__data.addHeaderValueRow("Basename", lambda x: os.path.basename(x.name)) self.__data.addHeaderValueRow("Name", lambda x: x.name) if obj.file is not None: self.__data.addHeaderValueRow("File", lambda x: x.file.filename) - if hasattr(obj, "path"): # That's a link if hasattr(obj, "filename"): + # External link link = lambda x: x.filename + SEPARATOR + x.path else: + # Soft link link = lambda x: x.path self.__data.addHeaderValueRow("Link", link) - else: - if silx.io.is_file(obj): - physical = lambda x: x.filename + SEPARATOR + x.name + showPhysicalLocation = False + + # External data (nothing to do with external links) + nExtSources = 0 + firstExtSource = None + extType = None + if silx.io.is_dataset(hdf5obj): + if hasattr(hdf5obj, "is_virtual"): + if hdf5obj.is_virtual: + extSources = hdf5obj.virtual_sources() + if extSources: + firstExtSource = extSources[0].file_name + SEPARATOR + extSources[0].dset_name + extType = "Virtual" + nExtSources = len(extSources) + if hasattr(hdf5obj, "external"): + extSources = hdf5obj.external + if extSources: + firstExtSource = extSources[0][0] + extType = "Raw" + nExtSources = len(extSources) + + if showPhysicalLocation: + def _physical_location(x): + if isinstance(obj, silx.gui.hdf5.H5Node): + return x.physical_filename + SEPARATOR + x.physical_name + elif silx.io.is_file(obj): + return x.filename + SEPARATOR + x.name elif obj.file is not None: - physical = lambda x: x.file.filename + SEPARATOR + x.name + return x.file.filename + SEPARATOR + x.name else: # Guess it is a virtual node - physical = "No physical location" - self.__data.addHeaderValueRow("Physical", physical) + return "No physical location" + + self.__data.addHeaderValueRow("Physical", _physical_location) + + if extType: + def _first_source(x): + # Absolute path + if os.path.isabs(firstExtSource): + return firstExtSource + + # Relative path with respect to the file directory + if isinstance(obj, silx.gui.hdf5.H5Node): + filename = x.physical_filename + elif silx.io.is_file(obj): + filename = x.filename + elif obj.file is not None: + filename = x.file.filename + else: + return firstExtSource + + if firstExtSource[0] == ".": + firstExtSource.pop(0) + return os.path.join(os.path.dirname(filename), firstExtSource) + + self.__data.addHeaderRow(headerLabel="External sources") + self.__data.addHeaderValueRow("Type", extType) + self.__data.addHeaderValueRow("Count", str(nExtSources)) + self.__data.addHeaderValueRow("First", _first_source) if hasattr(obj, "dtype"): diff --git a/silx/gui/data/NXdataWidgets.py b/silx/gui/data/NXdataWidgets.py index 224f337..271b267 100644 --- a/silx/gui/data/NXdataWidgets.py +++ b/silx/gui/data/NXdataWidgets.py @@ -370,6 +370,7 @@ class ArrayImagePlot(qt.QWidget): vmin=None, vmax=None, normalization=Colormap.LINEAR)) self._plot.getIntensityHistogramAction().setVisible(True) + self._plot.setKeepDataAspectRatio(True) # not closable self._selector = NumpyAxesSelector(self) diff --git a/silx/gui/data/TextFormatter.py b/silx/gui/data/TextFormatter.py index 98c37d7..8fd7c7c 100644 --- a/silx/gui/data/TextFormatter.py +++ b/silx/gui/data/TextFormatter.py @@ -267,6 +267,12 @@ class TextFormatter(qt.QObject): if vlen is not None: if vlen == six.text_type: # HDF5 UTF8 + # With h5py>=3 reading dataset returns bytes + if isinstance(data, (bytes, numpy.bytes_)): + try: + data = data.decode("utf-8") + except UnicodeDecodeError: + self.__formatSafeAscii(data) return self.__formatText(data) elif vlen == six.binary_type: # HDF5 ASCII @@ -289,7 +295,7 @@ class TextFormatter(qt.QObject): elif isinstance(data, list): text = [self.toString(d) for d in data] return "[" + " ".join(text) + "]" - elif isinstance(data, (numpy.ndarray)): + elif isinstance(data, numpy.ndarray): if dtype is None: dtype = data.dtype if data.shape == (): diff --git a/silx/gui/data/test/test_dataviewer.py b/silx/gui/data/test/test_dataviewer.py index 12a640e..dd01dd6 100644 --- a/silx/gui/data/test/test_dataviewer.py +++ b/silx/gui/data/test/test_dataviewer.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2019 European Synchrotron Radiation Facility +# Copyright (c) 2016-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -108,7 +108,7 @@ class AbstractDataViewerTests(TestCaseQt): self.assertIn(DataViews.IMAGE_MODE, availableModes) def test_image_bool(self): - data = numpy.zeros((10, 10), dtype=numpy.bool) + data = numpy.zeros((10, 10), dtype=bool) data[::2, ::2] = True widget = self.create_widget() widget.setData(data) @@ -117,7 +117,7 @@ class AbstractDataViewerTests(TestCaseQt): self.assertIn(DataViews.IMAGE_MODE, availableModes) def test_image_complex_data(self): - data = numpy.arange(3 ** 2, dtype=numpy.complex) + data = numpy.arange(3 ** 2, dtype=numpy.complex64) data.shape = [3] * 2 widget = self.create_widget() widget.setData(data) @@ -262,7 +262,7 @@ class TestDataView(TestCaseQt): line = [1, 2j, 3 + 3j, 4] image = [line, line, line, line] cube = [image, image, image, image] - data = numpy.array(cube, dtype=numpy.complex) + data = numpy.array(cube, dtype=numpy.complex64) return data def createDataViewWithData(self, dataViewClass, data): diff --git a/silx/gui/data/test/test_textformatter.py b/silx/gui/data/test/test_textformatter.py index 1a63074..d3050bf 100644 --- a/silx/gui/data/test/test_textformatter.py +++ b/silx/gui/data/test/test_textformatter.py @@ -36,6 +36,7 @@ import six from silx.gui.utils.testutils import TestCaseQt from silx.gui.utils.testutils import SignalListener from ..TextFormatter import TextFormatter +from silx.io.utils import h5py_read_dataset import h5py @@ -123,76 +124,79 @@ class TestTextFormatterWithH5py(TestCaseQt): dataset = self.h5File.create_dataset(testName, data=data, dtype=dtype) return dataset + def read_dataset(self, d): + return self.formatter.toString(d[()], dtype=d.dtype) + def testAscii(self): d = self.create_dataset(data=b"abc") - result = self.formatter.toString(d[()], dtype=d.dtype) + result = self.read_dataset(d) self.assertEqual(result, '"abc"') def testUnicode(self): d = self.create_dataset(data=u"i\u2661cookies") - result = self.formatter.toString(d[()], dtype=d.dtype) + result = self.read_dataset(d) self.assertEqual(len(result), 11) self.assertEqual(result, u'"i\u2661cookies"') def testBadAscii(self): d = self.create_dataset(data=b"\xF0\x9F\x92\x94") - result = self.formatter.toString(d[()], dtype=d.dtype) + result = self.read_dataset(d) self.assertEqual(result, 'b"\\xF0\\x9F\\x92\\x94"') def testVoid(self): d = self.create_dataset(data=numpy.void(b"abc\xF0")) - result = self.formatter.toString(d[()], dtype=d.dtype) + result = self.read_dataset(d) self.assertEqual(result, 'b"\\x61\\x62\\x63\\xF0"') def testEnum(self): dtype = h5py.special_dtype(enum=('i', {"RED": 0, "GREEN": 1, "BLUE": 42})) d = numpy.array(42, dtype=dtype) d = self.create_dataset(data=d) - result = self.formatter.toString(d[()], dtype=d.dtype) + result = self.read_dataset(d) self.assertEqual(result, 'BLUE(42)') def testRef(self): dtype = h5py.special_dtype(ref=h5py.Reference) d = numpy.array(self.h5File.ref, dtype=dtype) d = self.create_dataset(data=d) - result = self.formatter.toString(d[()], dtype=d.dtype) + result = self.read_dataset(d) self.assertEqual(result, 'REF') def testArrayAscii(self): d = self.create_dataset(data=[b"abc"]) - result = self.formatter.toString(d[()], dtype=d.dtype) + result = self.read_dataset(d) self.assertEqual(result, '["abc"]') def testArrayUnicode(self): dtype = h5py.special_dtype(vlen=six.text_type) d = numpy.array([u"i\u2661cookies"], dtype=dtype) d = self.create_dataset(data=d) - result = self.formatter.toString(d[()], dtype=d.dtype) + result = self.read_dataset(d) self.assertEqual(len(result), 13) self.assertEqual(result, u'["i\u2661cookies"]') def testArrayBadAscii(self): d = self.create_dataset(data=[b"\xF0\x9F\x92\x94"]) - result = self.formatter.toString(d[()], dtype=d.dtype) + result = self.read_dataset(d) self.assertEqual(result, '[b"\\xF0\\x9F\\x92\\x94"]') def testArrayVoid(self): d = self.create_dataset(data=numpy.void([b"abc\xF0"])) - result = self.formatter.toString(d[()], dtype=d.dtype) + result = self.read_dataset(d) self.assertEqual(result, '[b"\\x61\\x62\\x63\\xF0"]') def testArrayEnum(self): dtype = h5py.special_dtype(enum=('i', {"RED": 0, "GREEN": 1, "BLUE": 42})) d = numpy.array([42, 1, 100], dtype=dtype) d = self.create_dataset(data=d) - result = self.formatter.toString(d[()], dtype=d.dtype) + result = self.read_dataset(d) self.assertEqual(result, '[BLUE(42) GREEN(1) 100]') def testArrayRef(self): dtype = h5py.special_dtype(ref=h5py.Reference) d = numpy.array([self.h5File.ref, None], dtype=dtype) d = self.create_dataset(data=d) - result = self.formatter.toString(d[()], dtype=d.dtype) + result = self.read_dataset(d) self.assertEqual(result, '[REF NULL_REF]') diff --git a/silx/gui/fit/BackgroundWidget.py b/silx/gui/fit/BackgroundWidget.py index 2171e87..76bc043 100644 --- a/silx/gui/fit/BackgroundWidget.py +++ b/silx/gui/fit/BackgroundWidget.py @@ -1,6 +1,6 @@ # coding: utf-8 #/*########################################################################## -# Copyright (C) 2004-2017 V.A. Sole, European Synchrotron Radiation Facility +# Copyright (C) 2004-2020 V.A. Sole, European Synchrotron Radiation Facility # # This file is part of the PyMca X-ray Fluorescence Toolkit developed at # the ESRF by the Software group. @@ -337,7 +337,7 @@ class BackgroundWidget(qt.QWidget): pars = self.getParameters() # smoothed data - y = numpy.ravel(numpy.array(self._y)).astype(numpy.float) + y = numpy.ravel(numpy.array(self._y)).astype(numpy.float64) if pars["SmoothingFlag"]: ysmooth = filters.savitsky_golay(y, pars['SmoothingWidth']) f = [0.25, 0.5, 0.25] diff --git a/silx/gui/fit/FitWidget.py b/silx/gui/fit/FitWidget.py index 7279cd9..08731f1 100644 --- a/silx/gui/fit/FitWidget.py +++ b/silx/gui/fit/FitWidget.py @@ -720,7 +720,7 @@ class FitWidget(qt.QWidget): if __name__ == "__main__": import numpy - x = numpy.arange(1500).astype(numpy.float) + x = numpy.arange(1500).astype(numpy.float64) constant_bg = 3.14 p = [1000, 100., 30.0, diff --git a/silx/gui/hdf5/Hdf5Item.py b/silx/gui/hdf5/Hdf5Item.py index 11a08b6..e07f835 100755 --- a/silx/gui/hdf5/Hdf5Item.py +++ b/silx/gui/hdf5/Hdf5Item.py @@ -100,7 +100,7 @@ class Hdf5Item(Hdf5Node): """Returns the class of the stored object. When the object is in lazy loading, this method should be able to - return the type of the futrue loaded object. It allows to delay the + return the type of the future loaded object. It allows to delay the real load of the object. :rtype: silx.io.utils.H5Type @@ -114,7 +114,7 @@ class Hdf5Item(Hdf5Node): """Returns the class of the stored object. When the object is in lazy loading, this method should be able to - return the type of the futrue loaded object. It allows to delay the + return the type of the future loaded object. It allows to delay the real load of the object. :rtype: h5py.File or h5py.Dataset or h5py.Group @@ -383,12 +383,13 @@ class Hdf5Item(Hdf5Node): text = text.strip('"') # Check NX_class formatting lower = text.lower() + formatedNX_class = "" if lower.startswith('nx'): formatedNX_class = 'NX' + lower[2:] if lower == 'nxcansas': formatedNX_class = 'NXcanSAS' # That's the only class with capital letters... if text != formatedNX_class: - _logger.error("NX_class: %s is malformed (should be %s)", + _logger.error("NX_class: '%s' is malformed (should be '%s')", text, formatedNX_class) text = formatedNX_class @@ -614,17 +615,28 @@ class Hdf5Item(Hdf5Node): if role == qt.Qt.TextAlignmentRole: return qt.Qt.AlignTop | qt.Qt.AlignLeft if role == qt.Qt.DisplayRole: + # Mark as link link = self.linkClass if link is None: - return "" + pass + elif link == silx.io.utils.H5Type.HARD_LINK: + pass elif link == silx.io.utils.H5Type.EXTERNAL_LINK: return "External" elif link == silx.io.utils.H5Type.SOFT_LINK: return "Soft" - elif link == silx.io.utils.H5Type.HARD_LINK: - return "" else: return link.__name__ + # Mark as external data + if self.h5Class == silx.io.utils.H5Type.DATASET: + obj = self.obj + if hasattr(obj, "is_virtual"): + if obj.is_virtual: + return "Virtual" + if hasattr(obj, "external"): + if obj.external: + return "ExtRaw" + return "" if role == qt.Qt.ToolTipRole: return None return None diff --git a/silx/gui/hdf5/test/test_hdf5.py b/silx/gui/hdf5/test/test_hdf5.py index 5bd4223..fcfc02c 100755 --- a/silx/gui/hdf5/test/test_hdf5.py +++ b/silx/gui/hdf5/test/test_hdf5.py @@ -589,11 +589,11 @@ class TestNexusSortFilterProxyModel(TestCaseQt): self.assertListEqual(names, ["100aaa", "aaa100"]) -class TestH5Node(TestCaseQt): +class _TestModelBase(TestCaseQt): @classmethod def setUpClass(cls): - super(TestH5Node, cls).setUpClass() + super(_TestModelBase, cls).setUpClass() cls.tmpDirectory = tempfile.mkdtemp() cls.h5Filename = cls.createResource(cls.tmpDirectory) @@ -603,13 +603,18 @@ class TestH5Node(TestCaseQt): @classmethod def createResource(cls, directory): filename = os.path.join(directory, "base.h5") - externalFilename = os.path.join(directory, "base__external.h5") + extH5FileName = os.path.join(directory, "base__external.h5") + extDatFileName = os.path.join(directory, "base__external.dat") - externalh5 = h5py.File(externalFilename, mode="w") + externalh5 = h5py.File(extH5FileName, mode="w") externalh5["target/dataset"] = 50 externalh5["target/link"] = h5py.SoftLink("/target/dataset") + externalh5["/ext/vds0"] = [0, 1] + externalh5["/ext/vds1"] = [2, 3] externalh5.close() + numpy.array([0,1,10,10,2,3]).tofile(extDatFileName) + h5 = h5py.File(filename, mode="w") h5["group/dataset"] = 50 h5["link/soft_link"] = h5py.SoftLink("/group/dataset") @@ -617,12 +622,19 @@ class TestH5Node(TestCaseQt): h5["link/soft_link_to_link"] = h5py.SoftLink("/link/soft_link") h5["link/soft_link_to_file"] = h5py.SoftLink("/") h5["group/soft_link_relative"] = h5py.SoftLink("dataset") - h5["link/external_link"] = h5py.ExternalLink(externalFilename, "/target/dataset") - h5["link/external_link_to_link"] = h5py.ExternalLink(externalFilename, "/target/link") - h5["broken_link/external_broken_file"] = h5py.ExternalLink(externalFilename + "_not_exists", "/target/link") - h5["broken_link/external_broken_link"] = h5py.ExternalLink(externalFilename, "/target/not_exists") + h5["link/external_link"] = h5py.ExternalLink(extH5FileName, "/target/dataset") + h5["link/external_link_to_link"] = h5py.ExternalLink(extH5FileName, "/target/link") + h5["broken_link/external_broken_file"] = h5py.ExternalLink(extH5FileName + "_not_exists", "/target/link") + h5["broken_link/external_broken_link"] = h5py.ExternalLink(extH5FileName, "/target/not_exists") h5["broken_link/soft_broken_link"] = h5py.SoftLink("/group/not_exists") h5["broken_link/soft_link_to_broken_link"] = h5py.SoftLink("/group/not_exists") + layout = h5py.VirtualLayout((2,2), dtype=int) + layout[0] = h5py.VirtualSource("base__external.h5", name="/ext/vds0", shape=(2,), dtype=int) + layout[1] = h5py.VirtualSource("base__external.h5", name="/ext/vds1", shape=(2,), dtype=int) + h5.create_group("/ext") + h5["/ext"].create_virtual_dataset("virtual", layout) + external = [("base__external.dat", 0, 2*8), ("base__external.dat", 4*8, 2*8)] + h5["/ext"].create_dataset("raw", shape=(2,2), dtype=int, external=external) h5.close() return filename @@ -640,7 +652,7 @@ class TestH5Node(TestCaseQt): cls.qWaitForDestroy(ref) cls.h5File.close() shutil.rmtree(cls.tmpDirectory) - super(TestH5Node, cls).tearDownClass() + super(_TestModelBase, cls).tearDownClass() def getIndexFromPath(self, model, path): """ @@ -658,9 +670,114 @@ class TestH5Node(TestCaseQt): raise RuntimeError("Path not found") return index - def getH5NodeFromPath(self, model, path): + def getH5ItemFromPath(self, model, path): index = self.getIndexFromPath(model, path) - item = model.data(index, hdf5.Hdf5TreeModel.H5PY_ITEM_ROLE) + return model.data(index, hdf5.Hdf5TreeModel.H5PY_ITEM_ROLE) + + +class TestH5Item(_TestModelBase): + + def testFile(self): + path = ["base.h5"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "") + + def testGroup(self): + path = ["base.h5", "group"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "") + + def testDataset(self): + path = ["base.h5", "group", "dataset"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "") + + def testSoftLink(self): + path = ["base.h5", "link", "soft_link"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft") + + def testSoftLinkToLink(self): + path = ["base.h5", "link", "soft_link_to_link"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft") + + def testSoftLinkRelative(self): + path = ["base.h5", "group", "soft_link_relative"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft") + + def testExternalLink(self): + path = ["base.h5", "link", "external_link"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External") + + def testExternalLinkToLink(self): + path = ["base.h5", "link", "external_link_to_link"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External") + + def testExternalBrokenFile(self): + path = ["base.h5", "broken_link", "external_broken_file"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External") + + def testExternalBrokenLink(self): + path = ["base.h5", "broken_link", "external_broken_link"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External") + + def testSoftBrokenLink(self): + path = ["base.h5", "broken_link", "soft_broken_link"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft") + + def testSoftLinkToBrokenLink(self): + path = ["base.h5", "broken_link", "soft_link_to_broken_link"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft") + + def testDatasetFromSoftLinkToGroup(self): + path = ["base.h5", "link", "soft_link_to_group", "dataset"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "") + + def testDatasetFromSoftLinkToFile(self): + path = ["base.h5", "link", "soft_link_to_file", "link", "soft_link_to_group", "dataset"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "") + + def testExternalVirtual(self): + path = ["base.h5", "ext", "virtual"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Virtual") + + def testExternalRaw(self): + path = ["base.h5", "ext", "raw"] + h5item = self.getH5ItemFromPath(self.model, path) + + self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "ExtRaw") + + +class TestH5Node(_TestModelBase): + + def getH5NodeFromPath(self, model, path): + item = self.getH5ItemFromPath(model, path) h5node = hdf5.H5Node(item) return h5node @@ -824,6 +941,28 @@ class TestH5Node(TestCaseQt): self.assertEqual(h5node.local_basename, "dataset") self.assertEqual(h5node.local_name, "/link/soft_link_to_file/link/soft_link_to_group/dataset") + def testExternalVirtual(self): + path = ["base.h5", "ext", "virtual"] + h5node = self.getH5NodeFromPath(self.model, path) + + self.assertEqual(h5node.physical_filename, h5node.local_filename) + self.assertIn("base.h5", h5node.physical_filename) + self.assertEqual(h5node.physical_basename, "virtual") + self.assertEqual(h5node.physical_name, "/ext/virtual") + self.assertEqual(h5node.local_basename, "virtual") + self.assertEqual(h5node.local_name, "/ext/virtual") + + def testExternalRaw(self): + path = ["base.h5", "ext", "raw"] + h5node = self.getH5NodeFromPath(self.model, path) + + self.assertEqual(h5node.physical_filename, h5node.local_filename) + self.assertIn("base.h5", h5node.physical_filename) + self.assertEqual(h5node.physical_basename, "raw") + self.assertEqual(h5node.physical_name, "/ext/raw") + self.assertEqual(h5node.local_basename, "raw") + self.assertEqual(h5node.local_name, "/ext/raw") + class TestHdf5TreeView(TestCaseQt): """Test to check that icons module.""" @@ -993,6 +1132,7 @@ def suite(): test_suite.addTest(loadTests(TestNexusSortFilterProxyModel)) test_suite.addTest(loadTests(TestHdf5TreeView)) test_suite.addTest(loadTests(TestH5Node)) + test_suite.addTest(loadTests(TestH5Item)) return test_suite diff --git a/silx/gui/plot/ColorBar.py b/silx/gui/plot/ColorBar.py index 2b4677b..eff7689 100644 --- a/silx/gui/plot/ColorBar.py +++ b/silx/gui/plot/ColorBar.py @@ -142,11 +142,8 @@ class ColorBarWidget(qt.QWidget): self._isConnected = True def setVisible(self, isVisible): - # isHidden looks to be always synchronized, while isVisible is not - wasHidden = self.isHidden() qt.QWidget.setVisible(self, isVisible) - if wasHidden != self.isHidden(): - self.sigVisibleChanged.emit(not self.isHidden()) + self.sigVisibleChanged.emit(isVisible) def showEvent(self, event): self._connectPlot() diff --git a/silx/gui/plot/ComplexImageView.py b/silx/gui/plot/ComplexImageView.py index cd891cc..dc6bf63 100644 --- a/silx/gui/plot/ComplexImageView.py +++ b/silx/gui/plot/ComplexImageView.py @@ -318,7 +318,7 @@ class ComplexImageView(qt.QWidget): False to use provided data (do not modify!). """ if data is None: - data = numpy.zeros((0, 0), dtype=numpy.complex) + data = numpy.zeros((0, 0), dtype=numpy.complex64) previousData = self._plotImage.getComplexData(copy=False) diff --git a/silx/gui/plot/CurvesROIWidget.py b/silx/gui/plot/CurvesROIWidget.py index 4865b8e..5c9033e 100644 --- a/silx/gui/plot/CurvesROIWidget.py +++ b/silx/gui/plot/CurvesROIWidget.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2019 European Synchrotron Radiation Facility +# Copyright (c) 2004-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -1215,14 +1215,14 @@ class ROI(_RegionOfInterestBase): if len(idx): xw = x[idx] yw = y[idx] - rawCounts = yw.sum(dtype=numpy.float) + rawCounts = yw.sum(dtype=numpy.float64) deltaX = xw[-1] - xw[0] deltaY = yw[-1] - yw[0] if deltaX > 0.0: slope = (deltaY / deltaX) background = yw[0] + slope * (xw - xw[0]) netCounts = (rawCounts - - background.sum(dtype=numpy.float)) + background.sum(dtype=numpy.float64)) else: netCounts = 0.0 else: diff --git a/silx/gui/plot/ImageStack.py b/silx/gui/plot/ImageStack.py index c620d6d..3b652ca 100644 --- a/silx/gui/plot/ImageStack.py +++ b/silx/gui/plot/ImageStack.py @@ -150,7 +150,10 @@ class UrlList(qt.QWidget): self._listWidget.addItems(url_names) def _notifyCurrentUrlChanged(self, current, previous): - self.sigCurrentUrlChanged.emit(current.text()) + if current is None: + pass + else: + self.sigCurrentUrlChanged.emit(current.text()) def setUrl(self, url: DataUrl) -> None: assert isinstance(url, DataUrl) @@ -163,6 +166,9 @@ class UrlList(qt.QWidget): self._listWidget.setCurrentItem(item) self.sigCurrentUrlChanged.emit(item.text()) + def clear(self): + self._listWidget.clear() + class _ToggleableUrlSelectionTable(qt.QWidget): @@ -214,6 +220,9 @@ class _ToggleableUrlSelectionTable(qt.QWidget): def _propagateSignal(self, url): self.sigCurrentUrlChanged.emit(url) + def clear(self): + self._urlsTable.clear() + class UrlLoader(qt.QThread): """ @@ -326,6 +335,8 @@ class ImageStack(qt.QMainWindow): self._urlData = OrderedDict({}) self._current_url = None self._plot.clear() + self._urlsTable.clear() + self._slider.setMaximum(-1) def _preFetch(self, urls: list) -> None: """Pre-fetch the given urls if necessary @@ -414,14 +425,16 @@ class ImageStack(qt.QMainWindow): self._urlsTable.blockSignals(old_url_table) old_slider = self._slider.blockSignals(True) + self._slider.setMinimum(0) self._slider.setMaximum(len(self._urls) - 1) self._slider.blockSignals(old_slider) if self.getCurrentUrl() in self._urls: self.setCurrentUrl(self.getCurrentUrl()) else: - first_url = self._urls[list(self._urls.keys())[0]] - self.setCurrentUrl(first_url) + if len(self._urls.keys()) > 0: + first_url = self._urls[list(self._urls.keys())[0]] + self.setCurrentUrl(first_url) def getUrls(self) -> tuple: """ @@ -516,7 +529,11 @@ class ImageStack(qt.QMainWindow): :param index: url to be displayed :type: int """ - if index >= len(self._urls): + if index < 0: + return + if self._urls is None: + return + elif index >= len(self._urls): raise ValueError('requested index out of bounds') else: return self.setCurrentUrl(self._urls[index]) diff --git a/silx/gui/plot/ImageView.py b/silx/gui/plot/ImageView.py index fafd49f..8cc0cc6 100644 --- a/silx/gui/plot/ImageView.py +++ b/silx/gui/plot/ImageView.py @@ -56,7 +56,7 @@ from ..colors import Colormap from ..colors import cursorColorForColormap from .tools import LimitsToolBar from .Profile import ProfileToolBar - +from ...utils.proxy import docstring _logger = logging.getLogger(__name__) @@ -341,6 +341,10 @@ class ImageView(PlotWindow): self._radarView = RadarView(parent=self) self._radarView.visibleRectDragged.connect(self._radarViewCB) + self.__setCentralWidget() + + def __setCentralWidget(self): + """Set central widget with all its content""" layout = qt.QGridLayout() layout.addWidget(self.getWidgetHandle(), 0, 0) layout.addWidget(self._histoVPlot.getWidgetHandle(), 0, 1) @@ -365,6 +369,12 @@ class ImageView(PlotWindow): centralWidget.setLayout(layout) self.setCentralWidget(centralWidget) + @docstring(PlotWidget) + def setBackend(self, backend): + # Use PlotWidget here since we override PlotWindow behavior + PlotWidget.setBackend(self, backend) + self.__setCentralWidget() + def _dirtyCache(self): self._cache = None diff --git a/silx/gui/plot/MaskToolsWidget.py b/silx/gui/plot/MaskToolsWidget.py index a95e277..8ff8641 100644 --- a/silx/gui/plot/MaskToolsWidget.py +++ b/silx/gui/plot/MaskToolsWidget.py @@ -116,7 +116,8 @@ class ImageMask(BaseMask): """ if kind == 'edf': edfFile = EdfFile(filename, access="w+") - edfFile.WriteImage({}, self.getMask(copy=False), Append=0) + header = {"program_name": "silx-mask", "masked_value": "nonzero"} + edfFile.WriteImage(header, self.getMask(copy=False), Append=0) elif kind == 'tif': tiffFile = TiffIO(filename, mode='w') @@ -568,7 +569,9 @@ class MaskToolsWidget(BaseMaskToolsWidget): filename = dialog.selectedFiles()[0] dialog.close() + # Update the directory according to the user selection self.maskFileDir = os.path.dirname(filename) + try: self.load(filename) except RuntimeWarning as e: @@ -660,22 +663,35 @@ class MaskToolsWidget(BaseMaskToolsWidget): if os.path.exists(filename) and "HDF5" not in nameFilter: try: os.remove(filename) - except IOError: + except IOError as e: msg = qt.QMessageBox(self) + msg.setWindowTitle("Removing existing file") msg.setIcon(qt.QMessageBox.Critical) + + if hasattr(e, "strerror"): + strerror = e.strerror + else: + strerror = sys.exc_info()[1] msg.setText("Cannot save.\n" - "Input Output Error: %s" % (sys.exc_info()[1])) + "Input Output Error: %s" % strerror) msg.exec_() return + # Update the directory according to the user selection self.maskFileDir = os.path.dirname(filename) + try: self.save(filename, extension[1:]) except Exception as e: - raise msg = qt.QMessageBox(self) + msg.setWindowTitle("Saving mask file") msg.setIcon(qt.QMessageBox.Critical) - msg.setText("Cannot save file %s\n%s" % (filename, e.args[0])) + + if hasattr(e, "strerror"): + strerror = e.strerror + else: + strerror = sys.exc_info()[1] + msg.setText("Cannot save file %s\n%s" % (filename, strerror)) msg.exec_() def resetSelectionMask(self): @@ -727,7 +743,7 @@ class MaskToolsWidget(BaseMaskToolsWidget): # Convert from plot to array coords center = (event['points'][0] - self._origin) / self._scale size = event['points'][1] / self._scale - center = center.astype(numpy.int) # (row, col) + center = center.astype(numpy.int64) # (row, col) self._mask.updateEllipse(level, center[1], center[0], size[1], size[0], doMask) self._mask.commit() @@ -736,7 +752,7 @@ class MaskToolsWidget(BaseMaskToolsWidget): doMask = self._isMasking() # Convert from plot to array coords vertices = (event['points'] - self._origin) / self._scale - vertices = vertices.astype(numpy.int)[:, (1, 0)] # (row, col) + vertices = vertices.astype(numpy.int64)[:, (1, 0)] # (row, col) self._mask.updatePolygon(level, vertices, doMask) self._mask.commit() diff --git a/silx/gui/plot/PlotInteraction.py b/silx/gui/plot/PlotInteraction.py index d182a49..cfe140b 100644 --- a/silx/gui/plot/PlotInteraction.py +++ b/silx/gui/plot/PlotInteraction.py @@ -1604,6 +1604,8 @@ class DrawSelectMode(FocusManager): def __init__(self, plot, shape, label, color, width): eventHandlerClass = _DRAW_MODES[shape] + self._pan = Pan(plot) + self._panStart = None parameters = { 'shape': shape, 'label': label, @@ -1614,6 +1616,23 @@ class DrawSelectMode(FocusManager): ItemsInteractionForCombo(plot), eventHandlerClass(plot, parameters))) + def handleEvent(self, eventName, *args, **kwargs): + # Hack to add pan interaction to select-draw + # See issue Refactor PlotWidget interaction #3292 + if eventName == 'press' and args[2] == MIDDLE_BTN: + self._panStart = args[:2] + self._pan.beginDrag(*args) + return # Consume middle click events + elif eventName == 'release' and args[2] == MIDDLE_BTN: + self._panStart = None + self._pan.endDrag(self._panStart, args[:2], MIDDLE_BTN) + return # Consume middle click events + elif self._panStart is not None and eventName == 'move': + x, y = args[:2] + self._pan.drag(x, y, MIDDLE_BTN) + + super().handleEvent(eventName, *args, **kwargs) + def getDescription(self): """Returns the dict describing this interactive mode""" params = self.eventHandlers[1].parameters.copy() diff --git a/silx/gui/plot/PlotWidget.py b/silx/gui/plot/PlotWidget.py index 9f9f846..23b7fe9 100755 --- a/silx/gui/plot/PlotWidget.py +++ b/silx/gui/plot/PlotWidget.py @@ -52,7 +52,7 @@ from silx.utils.property import classproperty from silx.utils.deprecation import deprecated, deprecated_warning try: # Import matplotlib now to init matplotlib our way - from . import matplotlib + import silx.gui.utils.matplotlib # noqa except ImportError: _logger.debug("matplotlib not available") @@ -205,6 +205,12 @@ class PlotWidget(qt.QMainWindow): It provides the visible state. """ + _sigDefaultContextMenu = qt.Signal(qt.QMenu) + """Signal emitted when the default context menu of the plot is feed. + + It provides the menu which will be displayed. + """ + def __init__(self, parent=None, backend=None): self._autoreplot = False self._dirty = False @@ -222,8 +228,6 @@ class PlotWidget(qt.QMainWindow): self.setWindowTitle('PlotWidget') # Init the backend - if backend is None: - backend = silx.config.DEFAULT_PLOT_BACKEND self._backend = self.__getBackendClass(backend)(self, self) self.setCallback() # set _callback @@ -259,6 +263,12 @@ class PlotWidget(qt.QMainWindow): self._grid = None self._graphTitle = '' + self.__graphCursorShape = 'default' + + # Set axes margins + self.__axesDisplayed = True + self.__axesMargins = 0., 0., 0., 0. + self.setAxesMargins(.15, .1, .1, .15) self.setGraphTitle() self.setGraphXLabel() @@ -314,6 +324,9 @@ class PlotWidget(qt.QMainWindow): :raise ValueError: In case the backend is not supported :raise RuntimeError: If a backend is not available """ + if backend is None: + backend = silx.config.DEFAULT_PLOT_BACKEND + if callable(backend): return backend @@ -375,6 +388,98 @@ class PlotWidget(qt.QMainWindow): """ silx.config.DEFAULT_PLOT_BACKEND = backend + def setBackend(self, backend): + """Set the backend to use for rendering. + + Supported backends: + + - 'matplotlib' and 'mpl': Matplotlib with Qt. + - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1) + - 'none': No backend, to run headless for testing purpose. + + :param Union[str,BackendBase,List[Union[str,BackendBase]]] backend: + The backend to use, in: + 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none', + a :class:`BackendBase.BackendBase` class. + If multiple backends are provided, the first available one is used. + :raises ValueError: Unsupported backend descriptor + :raises RuntimeError: Error while loading a backend + """ + backend = self.__getBackendClass(backend)(self, self) + + # First save state that is stored in the backend + xaxis = self.getXAxis() + xmin, xmax = xaxis.getLimits() + ymin, ymax = self.getYAxis(axis='left').getLimits() + y2min, y2max = self.getYAxis(axis='right').getLimits() + isKeepDataAspectRatio = self.isKeepDataAspectRatio() + xTimeZone = xaxis.getTimeZone() + isXAxisTimeSeries = xaxis.getTickMode() == TickMode.TIME_SERIES + + isYAxisInverted = self.getYAxis().isInverted() + + # Remove all items from previous backend + for item in self.getItems(): + item._removeBackendRenderer(self._backend) + + # Switch backend + self._backend = backend + widget = self._backend.getWidgetHandle() + self.setCentralWidget(widget) + if widget is None: + _logger.info("PlotWidget backend does not support widget") + + # Mark as newly dirty + self._dirty = False + self._setDirtyPlot() + + # Synchronize/restore state + self._foregroundColorsUpdated() + self._backgroundColorsUpdated() + + self._backend.setGraphCursorShape(self.getGraphCursorShape()) + crosshairConfig = self.getGraphCursor() + if crosshairConfig is None: + self._backend.setGraphCursor(False, 'black', 1, '-') + else: + self._backend.setGraphCursor(True, *crosshairConfig) + + self._backend.setGraphTitle(self.getGraphTitle()) + self._backend.setGraphGrid(self.getGraphGrid()) + if self.isAxesDisplayed(): + self._backend.setAxesMargins(*self.getAxesMargins()) + else: + self._backend.setAxesMargins(0., 0., 0., 0.) + + # Set axes + xaxis = self.getXAxis() + self._backend.setGraphXLabel(xaxis.getLabel()) + self._backend.setXAxisTimeZone(xTimeZone) + self._backend.setXAxisTimeSeries(isXAxisTimeSeries) + self._backend.setXAxisLogarithmic( + xaxis.getScale() == items.Axis.LOGARITHMIC) + + for axis in ('left', 'right'): + self._backend.setGraphYLabel(self.getYAxis(axis).getLabel(), axis) + self._backend.setYAxisInverted(isYAxisInverted) + self._backend.setYAxisLogarithmic( + self.getYAxis().getScale() == items.Axis.LOGARITHMIC) + + # Finally restore aspect ratio and limits + self._backend.setKeepDataAspectRatio(isKeepDataAspectRatio) + self.setLimits(xmin, xmax, ymin, ymax, y2min, y2max) + + # Mark all items for update with new backend + for item in self.getItems(): + item._updated() + + def getBackend(self): + """Returns the backend currently used by :class:`PlotWidget`. + + :rtype: ~silx.gui.plot.backend.BackendBase.BackendBase + """ + return self._backend + def _getDirtyPlot(self): """Return the plot dirty flag. @@ -403,6 +508,8 @@ class PlotWidget(qt.QMainWindow): action = ClosePolygonInteractionAction(plot=self, parent=menu) menu.addAction(action) + self._sigDefaultContextMenu.emit(menu) + # Make sure the plot is updated, especially when the plot is in # draw interaction mode menu.aboutToHide.connect(self.__simulateMouseMove) @@ -538,6 +645,16 @@ class PlotWidget(qt.QMainWindow): self._dataBackgroundColor = color self._backgroundColorsUpdated() + dataBackgroundColor = qt.Property( + qt.QColor, getDataBackgroundColor, setDataBackgroundColor + ) + + backgroundColor = qt.Property(qt.QColor, getBackgroundColor, setBackgroundColor) + + foregroundColor = qt.Property(qt.QColor, getForegroundColor, setForegroundColor) + + gridColor = qt.Property(qt.QColor, getGridColor, setGridColor) + def showEvent(self, event): if self._autoreplot and self._dirty: self._backend.postRedisplay() @@ -2405,18 +2522,61 @@ class PlotWidget(qt.QMainWindow): assert(axis in ["left", "right"]) return self._yAxis if axis == "left" else self._yRightAxis - def setAxesDisplayed(self, displayed): + def setAxesDisplayed(self, displayed: bool): """Display or not the axes. :param bool displayed: If `True` axes are displayed. If `False` axes are not anymore visible and the margin used for them is removed. """ - self._backend.setAxesDisplayed(displayed) - self._setDirtyPlot() - self._sigAxesVisibilityChanged.emit(displayed) + if displayed != self.__axesDisplayed: + self.__axesDisplayed = displayed + if displayed: + self._backend.setAxesMargins(*self.__axesMargins) + else: + self._backend.setAxesMargins(0., 0., 0., 0.) + self._setDirtyPlot() + self._sigAxesVisibilityChanged.emit(displayed) + + def isAxesDisplayed(self) -> bool: + """Returns whether or not axes are currently displayed + + :rtype: bool + """ + return self.__axesDisplayed + + def setAxesMargins( + self, left: float, top: float, right: float, bottom: float): + """Set ratios of margins surrounding data plot area. + + All ratios must be within [0., 1.]. + Sums of ratios of opposed side must be < 1. + + :param float left: Left-side margin ratio. + :param float top: Top margin ratio + :param float right: Right-side margin ratio + :param float bottom: Bottom margin ratio + :raises ValueError: + """ + for value in (left, top, right, bottom): + if value < 0. or value > 1.: + raise ValueError("Margin ratios must be within [0., 1.]") + if left + right >= 1. or top + bottom >= 1.: + raise ValueError("Sum of ratios of opposed sides >= 1") + margins = left, top, right, bottom + + if margins != self.__axesMargins: + self.__axesMargins = margins + if self.isAxesDisplayed(): # Only apply if axes are displayed + self._backend.setAxesMargins(*margins) + self._setDirtyPlot() - def _isAxesDisplayed(self): - return self._backend.isAxesDisplayed() + def getAxesMargins(self): + """Returns ratio of margins surrounding data plot area. + + :return: (left, top, right, bottom) + :rtype: List[float] + """ + return self.__axesMargins def setYAxisInverted(self, flag=True): """Set the Y axis orientation. @@ -2980,11 +3140,19 @@ class PlotWidget(qt.QMainWindow): # Interaction support + def getGraphCursorShape(self): + """Returns the current cursor shape. + + :rtype: str + """ + return self.__graphCursorShape + def setGraphCursorShape(self, cursor=None): """Set the cursor shape. :param str cursor: Name of the cursor shape """ + self.__graphCursorShape = cursor self._backend.setGraphCursorShape(cursor) @deprecated(replacement='getItems', since_version='0.13') diff --git a/silx/gui/plot/PlotWindow.py b/silx/gui/plot/PlotWindow.py index a3b70c6..3cd605f 100644 --- a/silx/gui/plot/PlotWindow.py +++ b/silx/gui/plot/PlotWindow.py @@ -224,6 +224,56 @@ class PlotWindow(PlotWidget): self._sigAxesVisibilityChanged.connect(self._updateColorBarBackground) self._updateColorBarBackground() + if control: # Create control button only if requested + self.controlButton = qt.QToolButton() + self.controlButton.setText("Options") + self.controlButton.setToolButtonStyle(qt.Qt.ToolButtonTextBesideIcon) + self.controlButton.setAutoRaise(True) + self.controlButton.setPopupMode(qt.QToolButton.InstantPopup) + menu = qt.QMenu(self) + menu.aboutToShow.connect(self._customControlButtonMenu) + self.controlButton.setMenu(menu) + + self._positionWidget = None + if position: # Add PositionInfo widget to the bottom of the plot + if isinstance(position, abc.Iterable): + # Use position as a set of converters + converters = position + else: + converters = None + self._positionWidget = tools.PositionInfo( + plot=self, converters=converters) + # Set a snapping mode that is consistent with legacy one + self._positionWidget.setSnappingMode( + tools.PositionInfo.SNAPPING_CROSSHAIR | + tools.PositionInfo.SNAPPING_ACTIVE_ONLY | + tools.PositionInfo.SNAPPING_SYMBOLS_ONLY | + tools.PositionInfo.SNAPPING_CURVE | + tools.PositionInfo.SNAPPING_SCATTER) + + self.__setCentralWidget() + + # Creating the toolbar also create actions for toolbuttons + self._interactiveModeToolBar = tools.InteractiveModeToolBar( + parent=self, plot=self) + self.addToolBar(self._interactiveModeToolBar) + + self._toolbar = self._createToolBar(title='Plot', parent=self) + self.addToolBar(self._toolbar) + + self._outputToolBar = tools.OutputToolBar(parent=self, plot=self) + self._outputToolBar.getCopyAction().setVisible(copy) + self._outputToolBar.getSaveAction().setVisible(save) + self._outputToolBar.getPrintAction().setVisible(print_) + self.addToolBar(self._outputToolBar) + + # Activate shortcuts in PlotWindow widget: + for toolbar in (self._interactiveModeToolBar, self._outputToolBar): + for action in toolbar.actions(): + self.addAction(action) + + def __setCentralWidget(self): + """Set central widget to host plot backend, colorbar, and bottom bar""" gridLayout = qt.QGridLayout() gridLayout.setSpacing(0) gridLayout.setContentsMargins(0, 0, 0, 0) @@ -233,42 +283,15 @@ class PlotWindow(PlotWidget): gridLayout.setColumnStretch(0, 1) centralWidget = qt.QWidget(self) centralWidget.setLayout(gridLayout) - self.setCentralWidget(centralWidget) - self._positionWidget = None - - if control or position: + if hasattr(self, "controlButton") or self._positionWidget is not None: hbox = qt.QHBoxLayout() hbox.setContentsMargins(0, 0, 0, 0) - if control: - self.controlButton = qt.QToolButton() - self.controlButton.setText("Options") - self.controlButton.setToolButtonStyle(qt.Qt.ToolButtonTextBesideIcon) - self.controlButton.setAutoRaise(True) - self.controlButton.setPopupMode(qt.QToolButton.InstantPopup) - menu = qt.QMenu(self) - menu.aboutToShow.connect(self._customControlButtonMenu) - self.controlButton.setMenu(menu) - + if hasattr(self, "controlButton"): hbox.addWidget(self.controlButton) - if position: # Add PositionInfo widget to the bottom of the plot - if isinstance(position, abc.Iterable): - # Use position as a set of converters - converters = position - else: - converters = None - self._positionWidget = tools.PositionInfo( - plot=self, converters=converters) - # Set a snapping mode that is consistent with legacy one - self._positionWidget.setSnappingMode( - tools.PositionInfo.SNAPPING_CROSSHAIR | - tools.PositionInfo.SNAPPING_ACTIVE_ONLY | - tools.PositionInfo.SNAPPING_SYMBOLS_ONLY | - tools.PositionInfo.SNAPPING_CURVE | - tools.PositionInfo.SNAPPING_SCATTER) - + if self._positionWidget is not None: hbox.addWidget(self._positionWidget) hbox.addStretch(1) @@ -277,24 +300,12 @@ class PlotWindow(PlotWidget): gridLayout.addWidget(bottomBar, 1, 0, 1, -1) - # Creating the toolbar also create actions for toolbuttons - self._interactiveModeToolBar = tools.InteractiveModeToolBar( - parent=self, plot=self) - self.addToolBar(self._interactiveModeToolBar) - - self._toolbar = self._createToolBar(title='Plot', parent=self) - self.addToolBar(self._toolbar) - - self._outputToolBar = tools.OutputToolBar(parent=self, plot=self) - self._outputToolBar.getCopyAction().setVisible(copy) - self._outputToolBar.getSaveAction().setVisible(save) - self._outputToolBar.getPrintAction().setVisible(print_) - self.addToolBar(self._outputToolBar) + self.setCentralWidget(centralWidget) - # Activate shortcuts in PlotWindow widget: - for toolbar in (self._interactiveModeToolBar, self._outputToolBar): - for action in toolbar.actions(): - self.addAction(action) + @docstring(PlotWidget) + def setBackend(self, backend): + super(PlotWindow, self).setBackend(backend) + self.__setCentralWidget() # Recreate PlotWindow's central widget @docstring(PlotWidget) def setBackgroundColor(self, color): @@ -313,7 +324,7 @@ class PlotWindow(PlotWidget): def _updateColorBarBackground(self): """Update the colorbar background according to the state of the plot""" - if self._isAxesDisplayed(): + if self.isAxesDisplayed(): color = self.getBackgroundColor() else: color = self.getDataBackgroundColor() diff --git a/silx/gui/plot/ROIStatsWidget.py b/silx/gui/plot/ROIStatsWidget.py new file mode 100644 index 0000000..094d66a --- /dev/null +++ b/silx/gui/plot/ROIStatsWidget.py @@ -0,0 +1,780 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides widget for displaying statistics relative to a +Region of interest and an item +""" + + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "22/07/2019" + + +from contextlib import contextmanager +from silx.gui import qt +from silx.gui import icons +from silx.gui.plot.StatsWidget import _StatsWidgetBase, StatsTable, _Container +from silx.gui.plot.StatsWidget import UpdateModeWidget, UpdateMode +from silx.gui.widgets.TableWidget import TableWidget +from silx.gui.plot.items.roi import RegionOfInterest +from silx.gui.plot import items as plotitems +from silx.gui.plot.items.core import ItemChangedType +from silx.gui.plot3d import items as plot3ditems +from silx.gui.plot.CurvesROIWidget import ROI +from silx.gui.plot import stats as statsmdl +from collections import OrderedDict +from silx.utils.proxy import docstring +import silx.gui.plot.items.marker +import silx.gui.plot.items.shape +import functools +import logging + +_logger = logging.getLogger(__name__) + + +class _GetROIItemCoupleDialog(qt.QDialog): + """ + Dialog used to know which plot item and which roi he wants + """ + _COMPATIBLE_KINDS = ('curve', 'image', 'scatter', 'histogram') + + def __init__(self, parent=None, plot=None, rois=None): + qt.QDialog.__init__(self, parent=parent) + assert plot is not None + assert rois is not None + self._plot = plot + self._rois = rois + + self.setLayout(qt.QVBoxLayout()) + + # define the selection widget + self._selection_widget = qt.QWidget() + self._selection_widget.setLayout(qt.QHBoxLayout()) + self._kindCB = qt.QComboBox(parent=self) + self._selection_widget.layout().addWidget(self._kindCB) + self._itemCB = qt.QComboBox(parent=self) + self._selection_widget.layout().addWidget(self._itemCB) + self._roiCB = qt.QComboBox(parent=self) + self._selection_widget.layout().addWidget(self._roiCB) + self.layout().addWidget(self._selection_widget) + + # define modal buttons + types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel + self._buttonsModal = qt.QDialogButtonBox(parent=self) + self._buttonsModal.setStandardButtons(types) + self.layout().addWidget(self._buttonsModal) + self._buttonsModal.accepted.connect(self.accept) + self._buttonsModal.rejected.connect(self.reject) + + # connect signal / slot + self._kindCB.currentIndexChanged.connect(self._updateValidItemAndRoi) + + def _getCompatibleRois(self, kind): + """Return compatible rois for the given item kind""" + def is_compatible(roi, kind): + if isinstance(roi, RegionOfInterest): + return kind in ('image', 'scatter') + elif isinstance(roi, ROI): + return kind in ('curve', 'histogram') + else: + raise ValueError('kind not managed') + return list(filter(lambda x: is_compatible(x, kind), self._rois)) + + def exec_(self): + self._kindCB.clear() + self._itemCB.clear() + # filter kind without any items + self._valid_kinds = {} + # key is item type, value kinds + self._valid_rois = {} + # key is item type, value rois + self._kind_name_to_roi = {} + # key is (kind, roi name) value is roi + self._kind_name_to_item = {} + # key is (kind, legend name) value is item + for kind in _GetROIItemCoupleDialog._COMPATIBLE_KINDS: + def getItems(kind): + output = [] + for item in self._plot.getItems(): + type_ = self._plot._itemKind(item) + if type_ in kind and item.isVisible(): + output.append(item) + return output + + items = getItems(kind=kind) + rois = self._getCompatibleRois(kind=kind) + if len(items) > 0 and len(rois) > 0: + self._valid_kinds[kind] = items + self._valid_rois[kind] = rois + for roi in rois: + name = roi.getName() + self._kind_name_to_roi[(kind, name)] = roi + for item in items: + self._kind_name_to_item[(kind, item.getLegend())] = item + + # filter roi according to kinds + if len(self._valid_kinds) == 0: + _logger.warning('no couple item/roi detected for displaying stats') + return self.reject() + + for kind in self._valid_kinds: + self._kindCB.addItem(kind) + self._updateValidItemAndRoi() + + return qt.QDialog.exec_(self) + + def _updateValidItemAndRoi(self, *args, **kwargs): + self._itemCB.clear() + self._roiCB.clear() + kind = self._kindCB.currentText() + for roi in self._valid_rois[kind]: + self._roiCB.addItem(roi.getName()) + for item in self._valid_kinds[kind]: + self._itemCB.addItem(item.getLegend()) + + def getROI(self): + kind = self._kindCB.currentText() + roi_name = self._roiCB.currentText() + return self._kind_name_to_roi[(kind, roi_name)] + + def getItem(self): + kind = self._kindCB.currentText() + item_name = self._itemCB.currentText() + return self._kind_name_to_item[(kind, item_name)] + + +class ROIStatsItemHelper(object): + """Item utils to associate a plot item and a roi + + Display on one row statistics regarding the couple + (Item (plot item) / roi). + + :param Item plot_item: item for which we want statistics + :param Union[ROI,RegionOfInterest]: region of interest to use for + statistics. + """ + def __init__(self, plot_item, roi): + self._plot_item = plot_item + self._roi = roi + + @property + def roi(self): + """roi""" + return self._roi + + def roi_name(self): + if isinstance(self._roi, ROI): + return self._roi.getName() + elif isinstance(self._roi, RegionOfInterest): + return self._roi.getName() + else: + raise TypeError('Unmanaged roi type') + + @property + def roi_kind(self): + """roi class""" + return self._roi.__class__ + + # TODO: should call a util function from the wrapper ? + def item_kind(self): + """item kind""" + if isinstance(self._plot_item, plotitems.Curve): + return 'curve' + elif isinstance(self._plot_item, plotitems.ImageData): + return 'image' + elif isinstance(self._plot_item, plotitems.Scatter): + return 'scatter' + elif isinstance(self._plot_item, plotitems.Histogram): + return 'histogram' + elif isinstance(self._plot_item, (plot3ditems.ImageData, + plot3ditems.ScalarField3D)): + return 'image' + elif isinstance(self._plot_item, (plot3ditems.Scatter2D, + plot3ditems.Scatter3D)): + return 'scatter' + + @property + def item_legend(self): + """legend of the plot Item""" + return self._plot_item.getLegend() + + def id_key(self): + """unique key to represent the couple (item, roi)""" + return (self.item_kind(), self.item_legend, self.roi_kind, + self.roi_name()) + + +class _StatsROITable(_StatsWidgetBase, TableWidget): + """ + Table sued to display some statistics regarding a couple (item/roi) + """ + _LEGEND_HEADER_DATA = 'legend' + + _KIND_HEADER_DATA = 'kind' + + _ROI_HEADER_DATA = 'roi' + + sigUpdateModeChanged = qt.Signal(object) + """Signal emitted when the update mode changed""" + + def __init__(self, parent, plot): + TableWidget.__init__(self, parent) + _StatsWidgetBase.__init__(self, statsOnVisibleData=False, + displayOnlyActItem=False) + self.__region_edition_callback = {} + """We need to keep trace of the roi signals connection because + the roi emits the sigChanged during roi edition""" + self._items = {} + self.setRowCount(0) + self.setColumnCount(3) + + # Init headers + headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title()) + headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA) + self.setHorizontalHeaderItem(0, headerItem) + headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title()) + headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA) + self.setHorizontalHeaderItem(1, headerItem) + headerItem = qt.QTableWidgetItem(self._ROI_HEADER_DATA.title()) + headerItem.setData(qt.Qt.UserRole, self._ROI_HEADER_DATA) + self.setHorizontalHeaderItem(2, headerItem) + + self.setSortingEnabled(True) + self.setPlot(plot) + + self.__plotItemToItems = {} + """Key is plotItem, values is list of __RoiStatsItemWidget""" + self.__roiToItems = {} + """Key is roi, values is list of __RoiStatsItemWidget""" + self.__roisKeyToRoi = {} + + def add(self, item): + assert isinstance(item, ROIStatsItemHelper) + if item.id_key() in self._items: + _logger.warning(item.id_key(), 'is already present') + return None + self._items[item.id_key()] = item + self._addItem(item) + return item + + def _addItem(self, item): + """ + Add a _RoiStatsItemWidget item to the table. + + :param item: + :return: True if successfully added. + """ + if not isinstance(item, ROIStatsItemHelper): + # skipped because also receive all new plot item (Marker...) that + # we don't want to manage in this case. + return + # plotItem = item.getItem() + # roi = item.getROI() + kind = item.item_kind() + if kind not in statsmdl.BASIC_COMPATIBLE_KINDS: + _logger.info("Item has not a supported type: %s", item) + return False + + # register the roi and the kind + self._registerPlotItem(item) + self._registerROI(item) + + # Prepare table items + tableItems = [ + qt.QTableWidgetItem(), # Legend + qt.QTableWidgetItem(), # Kind + qt.QTableWidgetItem()] # roi + + for column in range(3, self.columnCount()): + header = self.horizontalHeaderItem(column) + name = header.data(qt.Qt.UserRole) + + formatter = self._statsHandler.formatters[name] + if formatter: + tableItem = formatter.tabWidgetItemClass() + else: + tableItem = qt.QTableWidgetItem() + + tooltip = self._statsHandler.stats[name].getToolTip(kind=kind) + if tooltip is not None: + tableItem.setToolTip(tooltip) + + tableItems.append(tableItem) + + # Disable sorting while adding table items + with self._disableSorting(): + # Add a row to the table + self.setRowCount(self.rowCount() + 1) + + # Add table items to the last row + row = self.rowCount() - 1 + for column, tableItem in enumerate(tableItems): + tableItem.setData(qt.Qt.UserRole, _Container(item)) + tableItem.setFlags( + qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable) + self.setItem(row, column, tableItem) + + # Update table items content + self._updateStats(item, data_changed=True) + + # Listen for item changes + # Using queued connection to avoid issue with sender + # being that of the signal calling the signal + item._plot_item.sigItemChanged.connect(self._plotItemChanged, + qt.Qt.QueuedConnection) + return True + + def _removeAllItems(self): + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + # item = self._tableItemToItem(tableItem) + # item.sigItemChanged.disconnect(self._plotItemChanged) + self.clearContents() + self.setRowCount(0) + + def clear(self): + self._removeAllItems() + + def setStats(self, statsHandler): + """Set which stats to display and the associated formatting. + + :param StatsHandler statsHandler: + Set the statistics to be displayed and how to format them using + """ + self._removeAllItems() + _StatsWidgetBase.setStats(self, statsHandler) + + self.setRowCount(0) + self.setColumnCount(len(self._statsHandler.stats) + 3) # + legend, kind and roi # noqa + + for index, stat in enumerate(self._statsHandler.stats.values()): + headerItem = qt.QTableWidgetItem(stat.name.capitalize()) + headerItem.setData(qt.Qt.UserRole, stat.name) + if stat.description is not None: + headerItem.setToolTip(stat.description) + self.setHorizontalHeaderItem(3 + index, headerItem) + + horizontalHeader = self.horizontalHeader() + if hasattr(horizontalHeader, 'setSectionResizeMode'): # Qt5 + horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents) + else: # Qt4 + horizontalHeader.setResizeMode(qt.QHeaderView.ResizeToContents) + + self._updateItemObserve() + + def _updateItemObserve(self, *args): + pass + + def _dataChanged(self, item): + pass + + def _updateStats(self, item, data_changed=False, roi_changed=False): + assert isinstance(item, ROIStatsItemHelper) + plotItem = item._plot_item + roi = item._roi + if item is None: + return + plot = self.getPlot() + if plot is None: + _logger.info("Plot not available") + return + + row = self._itemToRow(item) + if row is None: + _logger.error("This item is not in the table: %s", str(item)) + return + + statsHandler = self.getStatsHandler() + if statsHandler is not None: + stats = statsHandler.calculate(plotItem, plot, + onlimits=self._statsOnVisibleData, + roi=roi, data_changed=data_changed, + roi_changed=roi_changed) + else: + stats = {} + + with self._disableSorting(): + for name, tableItem in self._itemToTableItems(item).items(): + if name == self._LEGEND_HEADER_DATA: + text = self._plotWrapper.getLabel(plotItem) + tableItem.setText(text) + elif name == self._KIND_HEADER_DATA: + tableItem.setText(self._plotWrapper.getKind(plotItem)) + elif name == self._ROI_HEADER_DATA: + name = roi.getName() + tableItem.setText(name) + else: + value = stats.get(name) + if value is None: + _logger.error("Value not found for: %s", name) + tableItem.setText('-') + else: + tableItem.setText(str(value)) + + @contextmanager + def _disableSorting(self): + """Context manager that disables table sorting + + Previous state is restored when leaving + """ + sorting = self.isSortingEnabled() + if sorting: + self.setSortingEnabled(False) + yield + if sorting: + self.setSortingEnabled(sorting) + + def _itemToRow(self, item): + """Find the row corresponding to a plot item + + :param item: The plot item + :return: The corresponding row index + :rtype: Union[int,None] + """ + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + if self._tableItemToItem(tableItem) == item: + return row + return None + + def _tableItemToItem(self, tableItem): + """Find the plot item corresponding to a table item + + :param QTableWidgetItem tableItem: + :rtype: QObject + """ + container = tableItem.data(qt.Qt.UserRole) + return container() + + def _itemToTableItems(self, item): + """Find all table items corresponding to a plot item + + :param item: The plot item + :return: An ordered dict of column name to QTableWidgetItem mapping + for the given plot item. + :rtype: OrderedDict + """ + result = OrderedDict() + row = self._itemToRow(item) + if row is not None: + for column in range(self.columnCount()): + tableItem = self.item(row, column) + if self._tableItemToItem(tableItem) != item: + _logger.error("Table item/plot item mismatch") + else: + header = self.horizontalHeaderItem(column) + name = header.data(qt.Qt.UserRole) + result[name] = tableItem + return result + + def _plotItemToItems(self, plotItem): + """Return all _RoiStatsItemWidget associated to the plotItem + Needed for updating on itemChanged signal + """ + if plotItem in self.__plotItemToItems: + return [] + else: + return self.__plotItemToItems[plotItem] + + def _registerPlotItem(self, item): + if item._plot_item not in self.__plotItemToItems: + self.__plotItemToItems[item._plot_item] = set() + self.__plotItemToItems[item._plot_item].add(item) + + def _roiToItems(self, roi): + """Return all _RoiStatsItemWidget associated to the roi + Needed for updating on roiChanged signal + """ + if roi in self.__roiToItems: + return [] + else: + return self.__roiToItems[roi] + + def _registerROI(self, item): + if item._roi not in self.__roiToItems: + self.__roiToItems[item._roi] = set() + # TODO: normalize also sig name + if isinstance(item._roi, RegionOfInterest): + # item connection within sigRegionChanged should only be + # stopped during the region edition + self.__region_edition_callback[item._roi] = functools.partial( + self._updateAllStats, False, True) + item._roi.sigRegionChanged.connect(self.__region_edition_callback[item._roi]) + item._roi.sigEditingStarted.connect(functools.partial( + self._startFiltering, item._roi)) + item._roi.sigEditingFinished.connect(functools.partial( + self._endFiltering, item._roi)) + else: + item._roi.sigChanged.connect(functools.partial( + self._updateAllStats, False, True)) + self.__roiToItems[item._roi].add(item) + + def _startFiltering(self, roi): + roi.sigRegionChanged.disconnect(self.__region_edition_callback[roi]) + + def _endFiltering(self, roi): + roi.sigRegionChanged.connect(self.__region_edition_callback[roi]) + self._updateAllStats(roi_changed=True) + + def unregisterROI(self, roi): + if roi in self.__roiToItems: + del self.__roiToItems[roi] + if isinstance(roi, RegionOfInterest): + roi.sigRegionEditionStarted.disconnect(functools.partial( + self._startFiltering, roi)) + roi.sigRegionEditionFinished.disconnect(functools.partial( + self._startFiltering, roi)) + try: + roi.sigRegionChanged.disconnect(self._updateAllStats) + except: + pass + else: + roi.sigChanged.disconnect(self._updateAllStats) + + def _plotItemChanged(self, event): + """Handle modifications of the items. + + :param event: + """ + if event is ItemChangedType.DATA: + if self.getUpdateMode() is UpdateMode.MANUAL: + return + if self._skipPlotItemChangedEvent(event) is True: + return + else: + sender = self.sender() + for item in self.__plotItemToItems[sender]: + # TODO: get all concerned items + self._updateStats(item, data_changed=True) + # deal with stat items visibility + if event is ItemChangedType.VISIBLE: + if len(self._itemToTableItems(item).items()) > 0: + item_0 = list(self._itemToTableItems(item).values())[0] + row_index = item_0.row() + self.setRowHidden(row_index, not item.isVisible()) + + def _removeItem(self, itemKey): + if isinstance(itemKey, (silx.gui.plot.items.marker.Marker, + silx.gui.plot.items.shape.Shape)): + return + if itemKey not in self._items: + _logger.warning('key not recognized. Won\'t remove any item') + return + item = self._items[itemKey] + row = self._itemToRow(item) + if row is None: + kind = self._plotWrapper.getKind(item) + if kind in statsmdl.BASIC_COMPATIBLE_KINDS: + _logger.error("Removing item that is not in table: %s", str(item)) + return + item._plot_item.sigItemChanged.disconnect(self._plotItemChanged) + self.removeRow(row) + del self._items[itemKey] + + def _updateAllStats(self, is_request=False, roi_changed=False): + """Update stats for all rows in the table + + :param bool is_request: True if come from a manual request + """ + if (self.getUpdateMode() is UpdateMode.MANUAL and + not is_request and not roi_changed): + return + + with self._disableSorting(): + for row in range(self.rowCount()): + tableItem = self.item(row, 0) + item = self._tableItemToItem(tableItem) + self._updateStats(item, roi_changed=roi_changed, + data_changed=is_request) + + def _plotCurrentChanged(self, *args): + pass + + def _getRoi(self, kind, name): + """return the roi fitting the requirement kind, name. This information + is enough to be sure it is unique (in the widget)""" + for roi in self.__roiToItems: + roiName = roi.getName() + if isinstance(roi, kind) and name == roiName: + return roi + return None + + def _getPlotItem(self, kind, legend): + """return the plotItem fitting the requirement kind, legend. + This information is enough to be sure it is unique (in the widget)""" + for plotItem in self.__plotItemToItems: + if legend == plotItem.getLegend() and self._plotWrapper.getKind(plotItem) == kind: + return plotItem + return None + + +class ROIStatsWidget(qt.QMainWindow): + """ + Widget used to define stats item for a couple(roi, plotItem). + Stats will be computing on a given item (curve, image...) in the given + region of interest. + + It also provide an interface for adding and removing items. + + .. snapshotqt:: img/ROIStatsWidget.png + :width: 300px + :align: center + + from silx.gui import qt + from silx.gui.plot import Plot2D + from silx.gui.plot.ROIStatsWidget import ROIStatsWidget + from silx.gui.plot.items.roi import RectangleROI + import numpy + plot = Plot2D() + plot.addImage(numpy.arange(10000).reshape(100, 100), legend='img') + plot.show() + rectangleROI = RectangleROI() + rectangleROI.setGeometry(origin=(0, 100), size=(20, 20)) + rectangleROI.setName('Initial ROI') + widget = ROIStatsWidget(plot=plot) + widget.setStats([('sum', numpy.sum), ('mean', numpy.mean)]) + widget.registerROI(rectangleROI) + widget.addItem(roi=rectangleROI, plotItem=plot.getImage('img')) + widget.show() + + :param Union[qt.QWidget,None] parent: parent qWidget + :param PlotWindow plot: plot widget containing the items + :param stats: stats to display + :param tuple rois: tuple of rois to manage + """ + + def __init__(self, parent=None, plot=None, stats=None, rois=None): + qt.QMainWindow.__init__(self, parent) + + toolbar = qt.QToolBar(self) + icon = icons.getQIcon('add') + self._rois = list(rois) if rois is not None else [] + self._addAction = qt.QAction(icon, 'add item/roi', toolbar) + self._addAction.triggered.connect(self._addRoiStatsItem) + icon = icons.getQIcon('rm') + self._removeAction = qt.QAction(icon, 'remove item/roi', toolbar) + self._removeAction.triggered.connect(self._removeCurrentRow) + + toolbar.addAction(self._addAction) + toolbar.addAction(self._removeAction) + self.addToolBar(toolbar) + + self._plot = plot + self._statsROITable = _StatsROITable(parent=self, plot=self._plot) + self.setStats(stats=stats) + self.setCentralWidget(self._statsROITable) + self.setWindowFlags(qt.Qt.Widget) + + # expose API + self._setUpdateMode = self._statsROITable.setUpdateMode + self._updateAllStats = self._statsROITable._updateAllStats + + # setup + self._statsROITable.setSelectionBehavior(qt.QTableWidget.SelectRows) + + def registerROI(self, roi): + """For now there is no direct link between roi and plot. That is why + we need to add/register them to be able to associate them""" + self._rois.append(roi) + + def setPlot(self, plot): + """Define the plot to interact with + + :param Union[PlotWidget,SceneWidget,None] plot: + The plot containing the items on which statistics are applied + """ + self._plot = plot + + def getPlot(self): + return self._plot + + @docstring(_StatsROITable) + def setStats(self, stats): + if stats is not None: + self._statsROITable.setStats(statsHandler=stats) + + @docstring(_StatsROITable) + def getStatsHandler(self): + """ + + :return: + """ + return self._statsROITable.getStatsHandler() + + def _addRoiStatsItem(self): + """Ask the user what couple ROI / item he want to display""" + dialog = _GetROIItemCoupleDialog(parent=self, plot=self._plot, + rois=self._rois) + if dialog.exec_(): + self.addItem(roi=dialog.getROI(), plotItem=dialog.getItem()) + + def addItem(self, plotItem, roi): + """ + Add a row of statitstic regarding the couple (plotItem, roi) + + :param Item plotItem: item to use for statistics + :param roi: region of interest to limit the statistic. + :type: Union[ROI, RegionOfInterest] + :return: None of failed to add the item + :rtype: Union[None,ROIStatsItemHelper] + """ + statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem) + return self._statsROITable.add(item=statsItem) + + def removeItem(self, plotItem, roi): + """ + Remove the row associated to the couple (plotItem, roi) + + :param Item plotItem: item to use for statistics + :param roi: region of interest to limit the statistic. + :type: Union[ROI,RegionOfInterest] + """ + statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem) + self._statsROITable._removeItem(itemKey=statsItem.id_key()) + + def _removeCurrentRow(self): + def is1DKind(kind): + if kind in ('curve', 'histogram', 'scatter'): + return True + else: + return False + + currentRow = self._statsROITable.currentRow() + item_kind = self._statsROITable.item(currentRow, 1).text() + item_legend = self._statsROITable.item(currentRow, 0).text() + + roi_name = self._statsROITable.item(currentRow, 2).text() + roi_kind = ROI if is1DKind(item_kind) else RegionOfInterest + roi = self._statsROITable._getRoi(kind=roi_kind, name=roi_name) + if roi is None: + _logger.warning('failed to retrieve the roi you want to remove') + return False + plot_item = self._statsROITable._getPlotItem(kind=item_kind, + legend=item_legend) + if plot_item is None: + _logger.warning('failed to retrieve the plot item you want to' + 'remove') + return False + return self.removeItem(plotItem=plot_item, roi=roi) diff --git a/silx/gui/plot/ScatterMaskToolsWidget.py b/silx/gui/plot/ScatterMaskToolsWidget.py index 8ff2483..5ae8653 100644 --- a/silx/gui/plot/ScatterMaskToolsWidget.py +++ b/silx/gui/plot/ScatterMaskToolsWidget.py @@ -102,7 +102,7 @@ class ScatterMask(BaseMask): self._mask[indices] = level else: # unmask only where mask level is the specified value - indices_stencil = numpy.zeros_like(self._mask, dtype=numpy.bool) + indices_stencil = numpy.zeros_like(self._mask, dtype=bool) indices_stencil[indices] = True self._mask[numpy.logical_and(self._mask == level, indices_stencil)] = 0 self._notify() @@ -431,7 +431,9 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): filename = dialog.selectedFiles()[0] dialog.close() + # Update the directory according to the user selection self.maskFileDir = os.path.dirname(filename) + try: self.load(filename) # except RuntimeWarning as e: @@ -475,21 +477,35 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget): if os.path.exists(filename): try: os.remove(filename) - except IOError: + except IOError as e: msg = qt.QMessageBox(self) + msg.setWindowTitle("Removing existing file") msg.setIcon(qt.QMessageBox.Critical) + + if hasattr(e, "strerror"): + strerror = e.strerror + else: + strerror = sys.exc_info()[1] msg.setText("Cannot save.\n" - "Input Output Error: %s" % (sys.exc_info()[1])) + "Input Output Error: %s" % strerror) msg.exec_() return + # Update the directory according to the user selection self.maskFileDir = os.path.dirname(filename) + try: self.save(filename, extension[1:]) except Exception as e: msg = qt.QMessageBox(self) + msg.setWindowTitle("Saving mask file") msg.setIcon(qt.QMessageBox.Critical) - msg.setText("Cannot save file %s\n%s" % (filename, e.args[0])) + + if hasattr(e, "strerror"): + strerror = e.strerror + else: + strerror = sys.exc_info()[1] + msg.setText("Cannot save file %s\n%s" % (filename, strerror)) msg.exec_() def resetSelectionMask(self): diff --git a/silx/gui/plot/StackView.py b/silx/gui/plot/StackView.py index cb7ece1..40e0661 100644 --- a/silx/gui/plot/StackView.py +++ b/silx/gui/plot/StackView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2016-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -548,15 +548,8 @@ class StackView(qt.QMainWindow): perspective_changed = True self.setPerspective(perspective) - # This call to setColormap redefines the meaning of autoscale - # for 3D volume: take global min/max rather than frame min/max if self.__autoscaleCmap: - # note: there is no real autoscale in the stack widget, it is more - # like a hack computing stack min and max - colormap = self.getColormap() - _vmin, _vmax = colormap.getColormapRange(data=self._stack) - colormap.setVRange(_vmin, _vmax) - self.setColormap(colormap=colormap) + self.scaleColormapRangeToStack() # init plot self._stackItem.setStackData(self.__transposed_view, 0, copy=False) @@ -791,6 +784,22 @@ class StackView(qt.QMainWindow): # specifying a special colormap return self._plot.getDefaultColormap() + def scaleColormapRangeToStack(self): + """Scale colormap range according to current stack data. + + If no stack has been set through :meth:`setStack`, this has no effect. + + The range scaling mode is given by current :class:`Colormap`'s + :meth:`Colormap.getAutoscaleMode`. + """ + stack = self.getStack(copy=False, returnNumpyArray=True) + if stack is None: + return # No-op + + colormap = self.getColormap() + vmin, vmax = colormap.getColormapRange(data=stack[0]) + colormap.setVRange(vmin=vmin, vmax=vmax) + def setColormap(self, colormap=None, normalization=None, autoscale=None, vmin=None, vmax=None, colors=None): """Set the colormap and update active image. @@ -860,31 +869,14 @@ class StackView(qt.QMainWindow): vmax=vmax, colors=colors) - # Patch: since we don't apply this colormap to a single 2D data but - # a 2D stack we have to deal manually with vmin, vmax - if autoscale is None: - # set default - autoscale = False - elif autoscale and is_dataset(self._stack): - # h5py dataset has no min()/max() methods - raise RuntimeError( - "Cannot auto-scale colormap for a h5py dataset") - else: - autoscale = autoscale - self.__autoscaleCmap = autoscale - - if autoscale and (self._stack is not None): - _vmin, _vmax = _colormap.getColormapRange(data=self._stack) - _colormap.setVRange(vmin=_vmin, vmax=_vmax) - else: - if vmin is None and self._stack is not None: - _colormap.setVMin(self._stack.min()) - else: - _colormap.setVMin(vmin) - if vmax is None and self._stack is not None: - _colormap.setVMax(self._stack.max()) - else: - _colormap.setVMax(vmax) + if autoscale is not None: + deprecated_warning( + type_='function', + name='setColormap', + reason='autoscale argument is replaced by a method', + replacement='scaleColormapRangeToStack', + since_version='0.14') + self.__autoscaleCmap = bool(autoscale) cursorColor = cursorColorForColormap(_colormap.getName()) self._plot.setInteractiveMode('zoom', color=cursorColor) @@ -896,6 +888,12 @@ class StackView(qt.QMainWindow): if isinstance(activeImage, items.ColormapMixIn): activeImage.setColormap(self.getColormap()) + if self.__autoscaleCmap: + # scaleColormapRangeToStack needs to be called **after** + # setDefaultColormap so getColormap returns the right colormap + self.scaleColormapRangeToStack() + + @deprecated(replacement="getPlotWidget", since_version="0.13") def getPlot(self): return self.getPlotWidget() diff --git a/silx/gui/plot/StatsWidget.py b/silx/gui/plot/StatsWidget.py index 6b92ea0..26b48db 100644 --- a/silx/gui/plot/StatsWidget.py +++ b/silx/gui/plot/StatsWidget.py @@ -449,10 +449,12 @@ class _StatsWidgetBase(object): _displayOnlyActItem option.""" raise NotImplementedError('Base class') - def _updateStats(self, item): + def _updateStats(self, item, data_changed=False, roi_changed=False): """Update displayed information for given plot item :param item: The plot item + :param bool data_changed: is the item data changed. + :param bool roi_changed: is the associated roi changed. """ raise NotImplementedError('Base class') @@ -548,7 +550,7 @@ class _StatsWidgetBase(object): class StatsTable(_StatsWidgetBase, TableWidget): """ - TableWidget displaying for each curves contained by the Plot some + TableWidget displaying for each items contained by the Plot some information: * legend @@ -582,10 +584,10 @@ class StatsTable(_StatsWidgetBase, TableWidget): self.setColumnCount(2) # Init headers - headerItem = qt.QTableWidgetItem('Legend') + headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title()) headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA) self.setHorizontalHeaderItem(0, headerItem) - headerItem = qt.QTableWidgetItem('Kind') + headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title()) headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA) self.setHorizontalHeaderItem(1, headerItem) @@ -750,7 +752,7 @@ class StatsTable(_StatsWidgetBase, TableWidget): return else: item = self.sender() - self._updateStats(item) + self._updateStats(item, data_changed=True) # deal with stat items visibility if event is ItemChangedType.VISIBLE: if len(self._itemToTableItems(item).items()) > 0: @@ -812,7 +814,7 @@ class StatsTable(_StatsWidgetBase, TableWidget): self.setItem(row, column, tableItem) # Update table items content - self._updateStats(item) + self._updateStats(item, data_changed=True) # Listen for item changes # Using queued connection to avoid issue with sender @@ -845,10 +847,12 @@ class StatsTable(_StatsWidgetBase, TableWidget): self.clearContents() self.setRowCount(0) - def _updateStats(self, item): + def _updateStats(self, item, data_changed=False, roi_changed=False): """Update displayed information for given plot item :param item: The plot item + :param bool data_changed: is the item data changed. + :param bool roi_changed: is the associated roi changed. """ if item is None: return @@ -865,7 +869,8 @@ class StatsTable(_StatsWidgetBase, TableWidget): statsHandler = self.getStatsHandler() if statsHandler is not None: stats = statsHandler.calculate( - item, plot, self._statsOnVisibleData) + item, plot, self._statsOnVisibleData, + data_changed=data_changed, roi_changed=roi_changed) else: stats = {} @@ -895,7 +900,7 @@ class StatsTable(_StatsWidgetBase, TableWidget): for row in range(self.rowCount()): tableItem = self.item(row, 0) item = self._tableItemToItem(tableItem) - self._updateStats(item) + self._updateStats(item, data_changed=is_request) def _currentItemChanged(self, current, previous): """Handle change of selection in table and sync plot selection @@ -1392,7 +1397,7 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): """ return self._item_kind - def _setItem(self, item): + def _setItem(self, item, data_changed=True): if item is None: for stat_name, stat_widget in self._statQlineEdit.items(): stat_widget.setText('') @@ -1402,7 +1407,8 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): if plot is not None: statsValDict = self._statsHandler.calculate(item, plot, - self._statsOnVisibleData) + self._statsOnVisibleData, + data_changed=data_changed) for statName, statVal in list(statsValDict.items()): self._statQlineEdit[statName].setText(statVal) @@ -1417,7 +1423,7 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): items = list(filter(kind_filter, _items)) assert len(items) in (0, 1) _item = items[0] if len(items) == 1 else None - self._setItem(_item) + self._setItem(_item, data_changed=True) def _updateCurrentItem(self): self._updateItemObserve() @@ -1432,7 +1438,7 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget): def _removeItem(self, item): raise NotImplementedError('Display only the active item') - def _plotCurrentChanged(selfself, current): + def _plotCurrentChanged(self, current): raise NotImplementedError('Display only the active item') def _updateModeHasChanged(self): diff --git a/silx/gui/plot/_BaseMaskToolsWidget.py b/silx/gui/plot/_BaseMaskToolsWidget.py index aa4921c..3298498 100644 --- a/silx/gui/plot/_BaseMaskToolsWidget.py +++ b/silx/gui/plot/_BaseMaskToolsWidget.py @@ -378,7 +378,7 @@ class BaseMaskToolsWidget(qt.QWidget): """ super(BaseMaskToolsWidget, self).__init__(parent) # register if the user as force a color for the corresponding mask level - self._defaultColors = numpy.ones((self._maxLevelNumber + 1), dtype=numpy.bool) + self._defaultColors = numpy.ones((self._maxLevelNumber + 1), dtype=bool) # overlays colors set by the user self._overlayColors = numpy.zeros((self._maxLevelNumber + 1, 3), dtype=numpy.float32) @@ -459,6 +459,18 @@ class BaseMaskToolsWidget(qt.QWidget): self._levelWidget.setVisible(self._multipleMasks != 'single') self._clearAllBtn.setVisible(self._multipleMasks != 'single') + def setMaskFileDirectory(self, path): + """Set the default directory to use by load/save GUI tools + + The directory is also updated by the user, if he change the location + of the dialog. + """ + self.maskFileDir = path + + def getMaskFileDirectory(self): + """Get the default directory used by load/save GUI tools""" + return self.maskFileDir + @property def maskFileDir(self): """The directory from which to load/save mask from/to files.""" diff --git a/silx/gui/plot/_utils/dtime_ticklayout.py b/silx/gui/plot/_utils/dtime_ticklayout.py index 23c9dce..ebf775b 100644 --- a/silx/gui/plot/_utils/dtime_ticklayout.py +++ b/silx/gui/plot/_utils/dtime_ticklayout.py @@ -166,7 +166,7 @@ def setDateElement(dateTime, value, unit): def roundToElement(dateTime, unit): - """ Returns a copy of dateTime with the + """ Returns a copy of dateTime rounded to given unit :param datetime.datetime: date time object :param DtUnit unit: unit @@ -330,15 +330,19 @@ def niceDateTimeElement(value, unit, isRound=False): def findStartDate(dMin, dMax, nTicks): """ Rounds a date down to the nearest nice number of ticks """ - assert dMax > dMin, \ + assert dMax >= dMin, \ "dMin ({}) should come before dMax ({})".format(dMin, dMax) + if dMin == dMax: + # Fallback when range is smaller than microsecond resolution + return dMin, 1, DtUnit.MICRO_SECONDS + delta = dMax - dMin lengthSec = delta.total_seconds() _logger.debug("findStartDate: {}, {} (duration = {} sec, {} days)" .format(dMin, dMax, lengthSec, lengthSec / SECONDS_PER_DAY)) - length, unit = bestUnit(delta.total_seconds()) + length, unit = bestUnit(lengthSec) niceLength = niceDateTimeElement(length, unit) _logger.debug("Length: {:8.3f} {} (nice = {})" @@ -381,9 +385,9 @@ def dateRange(dMin, dMax, step, unit, includeFirstBeyond = False): """ if (unit == DtUnit.YEARS or unit == DtUnit.MONTHS or unit == DtUnit.MICRO_SECONDS): - - # Month and years will be converted to integers - assert int(step) > 0, "Integer value or tickstep is 0" + # No support for fractional month or year and resolution is microsecond + # In those cases, make sure the step is at least 1 + step = max(1, step) else: assert step > 0, "tickstep is 0" diff --git a/silx/gui/plot/actions/control.py b/silx/gui/plot/actions/control.py index ba69748..182ac78 100755 --- a/silx/gui/plot/actions/control.py +++ b/silx/gui/plot/actions/control.py @@ -50,7 +50,7 @@ from __future__ import division __authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"] __license__ = "MIT" -__date__ = "24/04/2018" +__date__ = "27/11/2020" from . import PlotAction import logging @@ -322,6 +322,7 @@ class ColormapAction(PlotAction): :param plot: :class:`.PlotWidget` instance on which to operate :param parent: See :class:`QAction` """ + def __init__(self, plot, parent=None): self._dialog = None # To store an instance of ColormapDialog super(ColormapAction, self).__init__( @@ -418,6 +419,7 @@ class ColorBarAction(PlotAction): :param plot: :class:`.PlotWidget` instance on which to operate :param parent: See :class:`QAction` """ + def __init__(self, plot, parent=None): self._dialog = None # To store an instance of ColorBar super(ColorBarAction, self).__init__( @@ -597,7 +599,7 @@ class ShowAxisAction(PlotAction): triggered=self._actionTriggered, checkable=True, parent=parent) - self.setChecked(self.plot._backend.isAxesDisplayed()) + self.setChecked(self.plot.isAxesDisplayed()) plot._sigAxesVisibilityChanged.connect(self.setChecked) def _actionTriggered(self, checked=False): @@ -632,3 +634,76 @@ class ClosePolygonInteractionAction(PlotAction): def _actionTriggered(self, checked=False): self.plot._eventHandler.validate() + + +class OpenGLAction(PlotAction): + """QAction controlling rendering of a :class:`.PlotWidget`. + + For now it can enable or not the OpenGL backend. + + :param plot: :class:`.PlotWidget` instance on which to operate + :param parent: See :class:`QAction` + """ + + def __init__(self, plot, parent=None): + # Uses two images for checked/unchecked states + self._states = { + "opengl": (icons.getQIcon('backend-opengl'), + "OpenGL rendering (fast)\nClick to disable OpenGL"), + "matplotlib": (icons.getQIcon('backend-opengl'), + "Matplotlib rendering (safe)\nClick to enable OpenGL"), + "unknown": (icons.getQIcon('backend-opengl'), + "Custom rendering") + } + + name = self._getBackendName(plot) + self.__state = name + icon, tooltip = self._states[name] + super(OpenGLAction, self).__init__( + plot, + icon=icon, + text='Enable/disable OpenGL rendering', + tooltip=tooltip, + triggered=self._actionTriggered, + checkable=True, + parent=parent) + + def _backendUpdated(self): + name = self._getBackendName(self.plot) + self.__state = name + icon, tooltip = self._states[name] + self.setIcon(icon) + self.setToolTip(tooltip) + self.setChecked(name == "opengl") + + def _getBackendName(self, plot): + backend = plot.getBackend() + name = type(backend).__name__.lower() + if "opengl" in name: + return "opengl" + elif "matplotlib" in name: + return "matplotlib" + else: + return "unknown" + + def _actionTriggered(self, checked=False): + plot = self.plot + name = self._getBackendName(self.plot) + if self.__state != name: + # THere is no event to know the backend was updated + # So here we check if there is a mismatch between the displayed state + # and the real state of the widget + self._backendUpdated() + return + if name != "opengl": + from silx.gui.utils import glutils + result = glutils.isOpenGLAvailable() + if not result: + qt.QMessageBox.critical(plot, "OpenGL rendering not available", result.error) + # Uncheck if needed + self._backendUpdated() + return + plot.setBackend("opengl") + else: + plot.setBackend("matplotlib") + self._backendUpdated() diff --git a/silx/gui/plot/actions/io.py b/silx/gui/plot/actions/io.py index 43b3b3a..f728b7a 100644 --- a/silx/gui/plot/actions/io.py +++ b/silx/gui/plot/actions/io.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2019 European Synchrotron Radiation Facility +# Copyright (c) 2004-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -37,7 +37,7 @@ from __future__ import division __authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"] __license__ = "MIT" -__date__ = "12/07/2018" +__date__ = "25/09/2020" from . import PlotAction from silx.io.utils import save1D, savespec, NEXUS_HDF5_EXT @@ -224,6 +224,43 @@ class SaveAction(PlotAction): ylabel = item.getYLabel() or self.plot.getYAxis().getLabel() return xlabel, ylabel + def _get1dData(self, item): + "provide xdata, [ydata], xlabel, [ylabel] and manages error bars" + xlabel, ylabel = self._getAxesLabels(item) + x_data = item.getXData(copy=False) + y_data = item.getYData(copy=False) + x_err = item.getXErrorData(copy=False) + y_err = item.getYErrorData(copy=False) + labels = [ylabel] + data = [y_data] + + if x_err is not None: + if numpy.isscalar(x_err): + data.append(numpy.zeros_like(y_data) + x_err) + labels.append(xlabel + "_errors") + elif x_err.ndim == 1: + data.append(x_err) + labels.append(xlabel + "_errors") + elif x_err.ndim == 2: + data.append(x_err[0]) + labels.append(xlabel + "_errors_below") + data.append(x_err[1]) + labels.append(xlabel + "_errors_above") + + if y_err is not None: + if numpy.isscalar(y_err): + data.append(numpy.zeros_like(y_data) + y_err) + labels.append(ylabel + "_errors") + elif y_err.ndim == 1: + data.append(y_err) + labels.append(ylabel + "_errors") + elif y_err.ndim == 2: + data.append(y_err[0]) + labels.append(ylabel + "_errors_below") + data.append(y_err[1]) + labels.append(ylabel + "_errors_above") + return x_data, data, xlabel, labels + @staticmethod def _selectWriteableOutputGroup(filename, parent): if os.path.exists(filename) and os.path.isfile(filename) \ @@ -291,16 +328,15 @@ class SaveAction(PlotAction): # .npy or nxdata fmt, csvdelim, autoheader = ("", "", False) - xlabel, ylabel = self._getAxesLabels(curve) - if nameFilter == self.CURVE_FILTER_NXDATA: return self._saveCurveAsNXdata(curve, filename) + xdata, data, xlabel, labels = self._get1dData(curve) + try: save1D(filename, - curve.getXData(copy=False), - curve.getYData(copy=False), - xlabel, [ylabel], + xdata, data, + xlabel, labels, fmt=fmt, csvdelim=csvdelim, autoheader=autoheader) except IOError: @@ -328,13 +364,11 @@ class SaveAction(PlotAction): curve = curves[0] scanno = 1 try: - xlabel = curve.getXLabel() or plot.getGraphXLabel() - ylabel = curve.getYLabel() or plot.getGraphYLabel(curve.getYAxis()) + xdata, data, xlabel, labels = self._get1dData(curve) + specfile = savespec(filename, - curve.getXData(copy=False), - curve.getYData(copy=False), - xlabel, - ylabel, + xdata, data, + xlabel, labels, fmt="%.7g", scan_number=1, mode="w", write_file_header=True, close_file=False) @@ -345,13 +379,10 @@ class SaveAction(PlotAction): for curve in curves[1:]: try: scanno += 1 - xlabel = curve.getXLabel() or plot.getGraphXLabel() - ylabel = curve.getYLabel() or plot.getGraphYLabel(curve.getYAxis()) + xdata, data, xlabel, labels = self._get1dData(curve) specfile = savespec(specfile, - curve.getXData(copy=False), - curve.getYData(copy=False), - xlabel, - ylabel, + xdata, data, + xlabel, labels, fmt="%.7g", scan_number=scanno, write_file_header=False, close_file=False) @@ -629,7 +660,7 @@ class SaveAction(PlotAction): # Check for correct file extension # Extract file extensions as .something extensions = [ext[ext.find('.'):] for ext in - nameFilter[nameFilter.find('(')+1:-1].split()] + nameFilter[nameFilter.find('(') + 1:-1].split()] for ext in extensions: if (len(filename) > len(ext) and filename[-len(ext):].lower() == ext.lower()): diff --git a/silx/gui/plot/backends/BackendBase.py b/silx/gui/plot/backends/BackendBase.py index bcc93a5..6fc1aa7 100755 --- a/silx/gui/plot/backends/BackendBase.py +++ b/silx/gui/plot/backends/BackendBase.py @@ -58,8 +58,8 @@ class BackendBase(object): self.__yLimits = {'left': (1., 100.), 'right': (1., 100.)} self.__yAxisInverted = False self.__keepDataAspectRatio = False + self.__xAxisTimeSeries = False self._xAxisTimeZone = None - self._axesDisplayed = True # Store a weakref to get access to the plot state. self._setPlot(plot) @@ -457,14 +457,14 @@ class BackendBase(object): :rtype: bool """ - raise NotImplementedError() + return self.__xAxisTimeSeries def setXAxisTimeSeries(self, isTimeSeries): """Set whether the X-axis is a time series :param bool flag: True to switch to time series, False for regular axis. """ - raise NotImplementedError() + self.__xAxisTimeSeries = bool(isTimeSeries) def setXAxisLogarithmic(self, flag): """Set the X axis scale between linear and log. @@ -548,20 +548,17 @@ class BackendBase(object): """ raise NotImplementedError() - def setAxesDisplayed(self, displayed): - """Display or not the axes. + def setAxesMargins(self, left: float, top: float, right: float, bottom: float): + """Set the size of plot margins as ratios. - :param bool displayed: If `True` axes are displayed. If `False` axes - are not anymore visible and the margin used for them is removed. - """ - self._axesDisplayed = displayed + Values are expected in [0., 1.] - def isAxesDisplayed(self): - """private because in some case it is possible that one of the two axes - are displayed and not the other. - This only check status set to axes from the public API + :param float left: + :param float top: + :param float right: + :param float bottom: """ - return self._axesDisplayed + pass def setForegroundColors(self, foregroundColor, gridColor): """Set foreground and grid colors used to display this widget. diff --git a/silx/gui/plot/backends/BackendMatplotlib.py b/silx/gui/plot/backends/BackendMatplotlib.py index 036e630..140672f 100755 --- a/silx/gui/plot/backends/BackendMatplotlib.py +++ b/silx/gui/plot/backends/BackendMatplotlib.py @@ -33,6 +33,7 @@ __date__ = "21/12/2018" import logging import datetime as dt +from typing import Tuple import numpy from pkg_resources import parse_version as _parse_version @@ -44,7 +45,7 @@ _logger = logging.getLogger(__name__) from ... import qt # First of all init matplotlib and set its backend -from ..matplotlib import FigureCanvasQTAgg +from ...utils.matplotlib import FigureCanvasQTAgg import matplotlib from matplotlib.container import Container from matplotlib.figure import Figure @@ -593,7 +594,7 @@ class BackendMatplotlib(BackendBase.BackendBase): if (len(color) == 4 and type(color[3]) in [type(1), numpy.uint8, numpy.int8]): - color = numpy.array(color, dtype=numpy.float) / 255. + color = numpy.array(color, dtype=numpy.float64) / 255. if yaxis == "right": axes = self.ax2 @@ -601,7 +602,7 @@ class BackendMatplotlib(BackendBase.BackendBase): else: axes = self.ax - picker = 3 + pickradius = 3 artists = [] # All the artists composing the curve @@ -627,7 +628,7 @@ class BackendMatplotlib(BackendBase.BackendBase): if hasattr(color, 'dtype') and len(color) == len(x): # scatter plot - if color.dtype not in [numpy.float32, numpy.float]: + if color.dtype not in [numpy.float32, numpy.float64]: actualColor = color / 255. else: actualColor = color @@ -639,7 +640,8 @@ class BackendMatplotlib(BackendBase.BackendBase): linestyle=linestyle, color=actualColor[0], linewidth=linewidth, - picker=picker, + picker=True, + pickradius=pickradius, marker=None) artists += list(curveList) @@ -647,7 +649,8 @@ class BackendMatplotlib(BackendBase.BackendBase): scatter = axes.scatter(x, y, color=actualColor, marker=marker, - picker=picker, + picker=True, + pickradius=pickradius, s=symbolsize**2) artists.append(scatter) @@ -665,7 +668,8 @@ class BackendMatplotlib(BackendBase.BackendBase): color=color, linewidth=linewidth, marker=symbol, - picker=picker, + picker=True, + pickradius=pickradius, markersize=symbolsize) artists += list(curveList) @@ -744,13 +748,13 @@ class BackendMatplotlib(BackendBase.BackendBase): color = numpy.array(color, copy=False) assert color.ndim == 2 and len(color) == len(x) - if color.dtype not in [numpy.float32, numpy.float]: + if color.dtype not in [numpy.float32, numpy.float64]: color = color.astype(numpy.float32) / 255. collection = TriMesh( Triangulation(x, y, triangles), alpha=alpha, - picker=0) # 0 enables picking on filled triangle + pickradius=0) # 0 enables picking on filled triangle collection.set_color(color) self.ax.add_collection(collection) @@ -893,7 +897,8 @@ class BackendMatplotlib(BackendBase.BackendBase): else: raise RuntimeError('A marker must at least have one coordinate') - line.set_picker(5) + line.set_picker(True) + line.set_pickradius(5) # All markers are overlays line.set_animated(True) @@ -1014,7 +1019,11 @@ class BackendMatplotlib(BackendBase.BackendBase): lambda item: item.isVisible() and item._backendRenderer is not None) count = len(items) for index, item in enumerate(items): - zorder = 1. + index / count + if item.getZValue() < 0.5: + # Make sure matplotlib z order is below the grid (with z=0.5) + zorder = 0.5 * index / count + else: # Make sure matplotlib z order is above the grid (> 0.5) + zorder = 1. + index / count if zorder != item._backendRenderer.get_zorder(): item._backendRenderer.set_zorder(zorder) @@ -1196,67 +1205,58 @@ class BackendMatplotlib(BackendBase.BackendBase): # Data <-> Pixel coordinates conversion - def _mplQtYAxisCoordConversion(self, y, asint=True): - """Qt origin (top) to/from matplotlib origin (bottom) conversion. + def _getDevicePixelRatio(self) -> float: + """Compatibility wrapper for devicePixelRatioF""" + return 1. - :param y: - :param bool asint: True to cast to int, False to keep as float + def _mplToQtPosition(self, x: float, y: float) -> Tuple[float, float]: + """Convert matplotlib "display" space coord to Qt widget logical pixel + """ + ratio = self._getDevicePixelRatio() + # Convert from matplotlib origin (bottom) to Qt origin (top) + # and apply device pixel ratio + return x / ratio, (self.fig.get_window_extent().height - y) / ratio - :rtype: float + def _qtToMplPosition(self, x: float, y: float) -> Tuple[float, float]: + """Convert Qt widget logical pixel to matplotlib "display" space coord """ - value = self.fig.get_window_extent().height - y - return int(value) if asint else value + ratio = self._getDevicePixelRatio() + # Apply device pixel ration and + # convert from Qt origin (top) to matplotlib origin (bottom) + return x * ratio, self.fig.get_window_extent().height - (y * ratio) def dataToPixel(self, x, y, axis): ax = self.ax2 if axis == "right" else self.ax - - pixels = ax.transData.transform_point((x, y)) - xPixel, yPixel = pixels.T - - # Convert from matplotlib origin (bottom) to Qt origin (top) - yPixel = self._mplQtYAxisCoordConversion(yPixel, asint=False) - - return xPixel, yPixel + displayPos = ax.transData.transform_point((x, y)).transpose() + return self._mplToQtPosition(*displayPos) def pixelToData(self, x, y, axis): ax = self.ax2 if axis == "right" else self.ax - - # Convert from Qt origin (top) to matplotlib origin (bottom) - y = self._mplQtYAxisCoordConversion(y, asint=False) - - inv = ax.transData.inverted() - x, y = inv.transform_point((x, y)) - return x, y + displayPos = self._qtToMplPosition(x, y) + return tuple(ax.transData.inverted().transform_point(displayPos)) def getPlotBoundsInPixels(self): bbox = self.ax.get_window_extent() # Warning this is not returning int... - return (int(bbox.xmin), - self._mplQtYAxisCoordConversion(bbox.ymax, asint=True), - int(bbox.width), - int(bbox.height)) + ratio = self._getDevicePixelRatio() + return tuple(int(value / ratio) for value in ( + bbox.xmin, + self.fig.get_window_extent().height - bbox.ymax, + bbox.width, + bbox.height)) - def setAxesDisplayed(self, displayed): - """Display or not the axes. + def setAxesMargins(self, left: float, top: float, right: float, bottom: float): + width, height = 1. - left - right, 1. - top - bottom + position = left, bottom, width, height + + # Toggle display of axes and viewbox rect + isFrameOn = position != (0., 0., 1., 1.) + self.ax.set_frame_on(isFrameOn) + self.ax2.set_frame_on(isFrameOn) + + self.ax.set_position(position) + self.ax2.set_position(position) - :param bool displayed: If `True` axes are displayed. If `False` axes - are not anymore visible and the margin used for them is removed. - """ - BackendBase.BackendBase.setAxesDisplayed(self, displayed) - if displayed: - # show axes and viewbox rect - self.ax.set_frame_on(True) - self.ax2.set_frame_on(True) - # set the default margins - self.ax.set_position([.15, .15, .75, .75]) - self.ax2.set_position([.15, .15, .75, .75]) - else: - # hide axes and viewbox rect - self.ax.set_frame_on(False) - self.ax2.set_frame_on(False) - # remove external margins - self.ax.set_position([0, 0, 1, 1]) - self.ax2.set_position([0, 0, 1, 1]) self._synchronizeBackgroundColors() self._synchronizeForegroundColors() self._plot._setDirtyPlot() @@ -1349,6 +1349,15 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): def postRedisplay(self): self._sigPostRedisplay.emit() + def _getDevicePixelRatio(self) -> float: + """Compatibility wrapper for devicePixelRatioF""" + if hasattr(self, 'devicePixelRatioF'): + ratio = self.devicePixelRatioF() + else: # Qt < 5.6 compatibility + ratio = float(self.devicePixelRatio()) + # Safety net: avoid returning 0 + return ratio if ratio != 0. else 1. + # Mouse event forwarding _MPL_TO_PLOT_BUTTONS = {1: 'left', 2: 'middle', 3: 'right'} @@ -1356,17 +1365,14 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): def _onMousePress(self, event): button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None) if button is not None: - self._plot.onMousePress( - event.x, self._mplQtYAxisCoordConversion(event.y), - button) + x, y = self._mplToQtPosition(event.x, event.y) + self._plot.onMousePress(int(x), int(y), button) def _onMouseMove(self, event): + x, y = self._mplToQtPosition(event.x, event.y) if self._graphCursor: position = self._plot.pixelToData( - event.x, - self._mplQtYAxisCoordConversion(event.y), - axis='left', - check=True) + x, y, axis='left', check=True) lineh, linev = self._graphCursor if position is not None: linev.set_visible(True) @@ -1380,19 +1386,17 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): self._plot._setDirtyPlot(overlayOnly=True) # onMouseMove must trigger replot if dirty flag is raised - self._plot.onMouseMove( - event.x, self._mplQtYAxisCoordConversion(event.y)) + self._plot.onMouseMove(int(x), int(y)) def _onMouseRelease(self, event): button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None) if button is not None: - self._plot.onMouseRelease( - event.x, self._mplQtYAxisCoordConversion(event.y), - button) + x, y = self._mplToQtPosition(event.x, event.y) + self._plot.onMouseRelease(int(x), int(y), button) def _onMouseWheel(self, event): - self._plot.onMouseWheel( - event.x, self._mplQtYAxisCoordConversion(event.y), event.step) + x, y = self._mplToQtPosition(event.x, event.y) + self._plot.onMouseWheel(int(x), int(y), event.step) def leaveEvent(self, event): """QWidget event handler""" @@ -1406,8 +1410,9 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib): # picking def pickItem(self, x, y, item): + xDisplay, yDisplay = self._qtToMplPosition(x, y) mouseEvent = MouseEvent( - 'button_press_event', self, x, self._mplQtYAxisCoordConversion(y)) + 'button_press_event', self, int(xDisplay), int(yDisplay)) # Override axes and data position with the axes mouseEvent.inaxes = item.axes mouseEvent.xdata, mouseEvent.ydata = self.pixelToData( diff --git a/silx/gui/plot/backends/BackendOpenGL.py b/silx/gui/plot/backends/BackendOpenGL.py index cf1da31..909d18a 100755 --- a/silx/gui/plot/backends/BackendOpenGL.py +++ b/silx/gui/plot/backends/BackendOpenGL.py @@ -43,12 +43,7 @@ from ... import qt from ..._glutils import gl from ... import _glutils as glu -from .glutils import ( - GLLines2D, GLPlotTriangles, - GLPlotCurve2D, GLPlotColormap, GLPlotRGBAImage, GLPlotFrame2D, - mat4Ortho, mat4Identity, - LEFT, RIGHT, BOTTOM, TOP, - Text2D, FilledShape2D) +from . import glutils from .glutils.PlotImageFile import saveImageToFile _logger = logging.getLogger(__name__) @@ -216,7 +211,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._backgroundColor = 1., 1., 1., 1. self._dataBackgroundColor = 1., 1., 1., 1. - self.matScreenProj = mat4Identity() + self.matScreenProj = glutils.mat4Identity() self._progBase = glu.Program( _baseVertShd, _baseFragShd, attrib0='position') @@ -231,10 +226,13 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): self._glGarbageCollector = [] - self._plotFrame = GLPlotFrame2D( + self._plotFrame = glutils.GLPlotFrame2D( foregroundColor=(0., 0., 0., 1.), gridColor=(.7, .7, .7, 1.), - margins={'left': 100, 'right': 50, 'top': 50, 'bottom': 50}) + marginRatios=(.15, .1, .1, .15)) + self._plotFrame.size = ( # Init size with size int + int(self.getDevicePixelRatio() * 640), + int(self.getDevicePixelRatio() * 480)) # Make postRedisplay asynchronous using Qt signal self._sigPostRedisplay.connect( @@ -254,50 +252,43 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): def mousePressEvent(self, event): if event.button() not in self._MOUSE_BTNS: return super(BackendOpenGL, self).mousePressEvent(event) - xPixel = event.x() * self.getDevicePixelRatio() - yPixel = event.y() * self.getDevicePixelRatio() - btn = self._MOUSE_BTNS[event.button()] - self._plot.onMousePress(xPixel, yPixel, btn) + self._plot.onMousePress( + event.x(), event.y(), self._MOUSE_BTNS[event.button()]) event.accept() def mouseMoveEvent(self, event): - xPixel = event.x() * self.getDevicePixelRatio() - yPixel = event.y() * self.getDevicePixelRatio() - - # Handle crosshair - inXPixel, inYPixel = self._mouseInPlotArea(xPixel, yPixel) - isCursorInPlot = inXPixel == xPixel and inYPixel == yPixel + qtPos = event.x(), event.y() previousMousePosInPixels = self._mousePosInPixels - self._mousePosInPixels = (xPixel, yPixel) if isCursorInPlot else None + if qtPos == self._mouseInPlotArea(*qtPos): + devicePixelRatio = self.getDevicePixelRatio() + devicePos = qtPos[0] * devicePixelRatio, qtPos[1] * devicePixelRatio + self._mousePosInPixels = devicePos # Mouse in plot area + else: + self._mousePosInPixels = None # Mouse outside plot area + if (self._crosshairCursor is not None and previousMousePosInPixels != self._mousePosInPixels): # Avoid replot when cursor remains outside plot area self._plot._setDirtyPlot(overlayOnly=True) - self._plot.onMouseMove(xPixel, yPixel) + self._plot.onMouseMove(*qtPos) event.accept() def mouseReleaseEvent(self, event): if event.button() not in self._MOUSE_BTNS: return super(BackendOpenGL, self).mouseReleaseEvent(event) - xPixel = event.x() * self.getDevicePixelRatio() - yPixel = event.y() * self.getDevicePixelRatio() - - btn = self._MOUSE_BTNS[event.button()] - self._plot.onMouseRelease(xPixel, yPixel, btn) + self._plot.onMouseRelease( + event.x(), event.y(), self._MOUSE_BTNS[event.button()]) event.accept() def wheelEvent(self, event): - xPixel = event.x() * self.getDevicePixelRatio() - yPixel = event.y() * self.getDevicePixelRatio() - if hasattr(event, 'angleDelta'): # Qt 5 delta = event.angleDelta().y() else: # Qt 4 support delta = event.delta() angleInDegrees = delta / 8. - self._plot.onMouseWheel(xPixel, yPixel, angleInDegrees) + self._plot.onMouseWheel(event.x(), event.y(), angleInDegrees) event.accept() def leaveEvent(self, _): @@ -371,7 +362,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): gl.glUniform1i(self._progTex.uniforms['tex'], texUnit) gl.glUniformMatrix4fv(self._progTex.uniforms['matrix'], 1, gl.GL_TRUE, - mat4Identity().astype(numpy.float32)) + glutils.mat4Identity().astype(numpy.float32)) gl.glEnableVertexAttribArray(self._progTex.attributes['position']) gl.glVertexAttribPointer(self._progTex.attributes['position'], @@ -405,10 +396,11 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT) # Check if window is large enough - plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] - if plotWidth <= 2 or plotHeight <= 2: + if self._plotFrame.plotSize <= (2, 2): return + # Sync plot frame with window + self._plotFrame.devicePixelRatio = self.getDevicePixelRatio() # self._paintDirectGL() self._paintFBOGL() @@ -422,7 +414,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): True to render items that are overlays. """ # Values that are often used - plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] + plotWidth, plotHeight = self._plotFrame.plotSize isXLog = self._plotFrame.xAxis.isLog isYLog = self._plotFrame.yAxis.isLog isYInverted = self._plotFrame.isYAxisInverted @@ -431,6 +423,9 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): labels = [] pixelOffset = 3 + context = glutils.RenderContext( + isXLog=isXLog, isYLog=isYLog, dpi=self.getDotsPerInch()) + for plotItem in self.getItemsFromBackToFront( condition=lambda i: i.isVisible() and i.isOverlay() == overlay): if plotItem._backendRenderer is None: @@ -438,20 +433,16 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): item = plotItem._backendRenderer - if isinstance(item, (GLPlotCurve2D, - GLPlotColormap, - GLPlotRGBAImage, - GLPlotTriangles)): # Render data items + if isinstance(item, glutils.GLPlotItem): # Render data items gl.glViewport(self._plotFrame.margins.left, self._plotFrame.margins.bottom, plotWidth, plotHeight) - - if isinstance(item, GLPlotCurve2D) and item.info.get('yAxis') == 'right': - item.render(self._plotFrame.transformedDataY2ProjMat, - isXLog, isYLog) + # Set matrix + if item.yaxis == 'right': + context.matrix = self._plotFrame.transformedDataY2ProjMat else: - item.render(self._plotFrame.transformedDataProjMat, - isXLog, isYLog) + context.matrix = self._plotFrame.transformedDataProjMat + item.render(context) elif isinstance(item, _ShapeItem): # Render shape items gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) @@ -463,53 +454,67 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): if item['shape'] == 'hline': width = self._plotFrame.size[0] - _, yPixel = self._plot.dataToPixel( - None, item['y'], axis='left', check=False) - points = numpy.array(((0., yPixel), (width, yPixel)), - dtype=numpy.float32) + _, yPixel = self._plotFrame.dataToPixel( + 0.5 * sum(self._plotFrame.dataRanges[0]), + item['y'], + axis='left') + subShapes = [numpy.array(((0., yPixel), (width, yPixel)), + dtype=numpy.float32)] elif item['shape'] == 'vline': - xPixel, _ = self._plot.dataToPixel( - item['x'], None, axis='left', check=False) + xPixel, _ = self._plotFrame.dataToPixel( + item['x'], + 0.5 * sum(self._plotFrame.dataRanges[1]), + axis='left') height = self._plotFrame.size[1] - points = numpy.array(((xPixel, 0), (xPixel, height)), - dtype=numpy.float32) + subShapes = [numpy.array(((xPixel, 0), (xPixel, height)), + dtype=numpy.float32)] else: - points = numpy.array([ - self._plot.dataToPixel(x, y, axis='left', check=False) - for (x, y) in zip(item['x'], item['y'])]) - - # Draw the fill - if (item['fill'] is not None and - item['shape'] not in ('hline', 'vline')): - self._progBase.use() - gl.glUniformMatrix4fv( - self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, - self.matScreenProj.astype(numpy.float32)) - gl.glUniform2i(self._progBase.uniforms['isLog'], False, False) - gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) - - shape2D = FilledShape2D( - points, style=item['fill'], color=item['color']) - shape2D.render( - posAttrib=self._progBase.attributes['position'], - colorUnif=self._progBase.uniforms['color'], - hatchStepUnif=self._progBase.uniforms['hatchStep']) - - # Draw the stroke - if item['linestyle'] not in ('', ' ', None): - if item['shape'] != 'polylines': - # close the polyline - points = numpy.append(points, - numpy.atleast_2d(points[0]), axis=0) - - lines = GLLines2D(points[:, 0], points[:, 1], - style=item['linestyle'], - color=item['color'], - dash2ndColor=item['linebgcolor'], - width=item['linewidth']) - lines.render(self.matScreenProj) + # Split sub-shapes at not finite values + splits = numpy.nonzero(numpy.logical_not(numpy.logical_and( + numpy.isfinite(item['x']), numpy.isfinite(item['y']))))[0] + splits = numpy.concatenate(([-1], splits, [len(item['x'])])) + subShapes = [] + for begin, end in zip(splits[:-1] + 1, splits[1:]): + if end > begin: + subShapes.append(numpy.array([ + self._plotFrame.dataToPixel(x, y, axis='left') + for (x, y) in zip(item['x'][begin:end], item['y'][begin:end])])) + + for points in subShapes: # Draw each sub-shape + # Draw the fill + if (item['fill'] is not None and + item['shape'] not in ('hline', 'vline')): + self._progBase.use() + gl.glUniformMatrix4fv( + self._progBase.uniforms['matrix'], 1, gl.GL_TRUE, + self.matScreenProj.astype(numpy.float32)) + gl.glUniform2i(self._progBase.uniforms['isLog'], False, False) + gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.) + + shape2D = glutils.FilledShape2D( + points, style=item['fill'], color=item['color']) + shape2D.render( + posAttrib=self._progBase.attributes['position'], + colorUnif=self._progBase.uniforms['color'], + hatchStepUnif=self._progBase.uniforms['hatchStep']) + + # Draw the stroke + if item['linestyle'] not in ('', ' ', None): + if item['shape'] != 'polylines': + # close the polyline + points = numpy.append(points, + numpy.atleast_2d(points[0]), axis=0) + + lines = glutils.GLLines2D( + points[:, 0], points[:, 1], + style=item['linestyle'], + color=item['color'], + dash2ndColor=item['linebgcolor'], + width=item['linewidth']) + context.matrix = self.matScreenProj + lines.render(context) elif isinstance(item, _MarkerItem): gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1]) @@ -522,76 +527,103 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): continue if xCoord is None or yCoord is None: - pixelPos = self._plot.dataToPixel( - xCoord, yCoord, axis=yAxis, check=False) - if xCoord is None: # Horizontal line in data space + pixelPos = self._plotFrame.dataToPixel( + 0.5 * sum(self._plotFrame.dataRanges[0]), + yCoord, + axis=yAxis) + if item['text'] is not None: x = self._plotFrame.size[0] - \ self._plotFrame.margins.right - pixelOffset y = pixelPos[1] - pixelOffset - label = Text2D(item['text'], x, y, - color=item['color'], - bgColor=(1., 1., 1., 0.5), - align=RIGHT, valign=BOTTOM) + label = glutils.Text2D( + item['text'], x, y, + color=item['color'], + bgColor=(1., 1., 1., 0.5), + align=glutils.RIGHT, + valign=glutils.BOTTOM, + devicePixelRatio=self.getDevicePixelRatio()) labels.append(label) width = self._plotFrame.size[0] - lines = GLLines2D((0, width), (pixelPos[1], pixelPos[1]), - style=item['linestyle'], - color=item['color'], - width=item['linewidth']) - lines.render(self.matScreenProj) + lines = glutils.GLLines2D( + (0, width), (pixelPos[1], pixelPos[1]), + style=item['linestyle'], + color=item['color'], + width=item['linewidth']) + context.matrix = self.matScreenProj + lines.render(context) else: # yCoord is None: vertical line in data space + yRange = self._plotFrame.dataRanges[1 if yAxis == 'left' else 2] + pixelPos = self._plotFrame.dataToPixel( + xCoord, 0.5 * sum(yRange), axis=yAxis) + if item['text'] is not None: x = pixelPos[0] + pixelOffset y = self._plotFrame.margins.top + pixelOffset - label = Text2D(item['text'], x, y, - color=item['color'], - bgColor=(1., 1., 1., 0.5), - align=LEFT, valign=TOP) + label = glutils.Text2D( + item['text'], x, y, + color=item['color'], + bgColor=(1., 1., 1., 0.5), + align=glutils.LEFT, + valign=glutils.TOP, + devicePixelRatio=self.getDevicePixelRatio()) labels.append(label) height = self._plotFrame.size[1] - lines = GLLines2D((pixelPos[0], pixelPos[0]), (0, height), - style=item['linestyle'], - color=item['color'], - width=item['linewidth']) - lines.render(self.matScreenProj) + lines = glutils.GLLines2D( + (pixelPos[0], pixelPos[0]), (0, height), + style=item['linestyle'], + color=item['color'], + width=item['linewidth']) + context.matrix = self.matScreenProj + lines.render(context) else: - pixelPos = self._plot.dataToPixel( - xCoord, yCoord, axis=yAxis, check=True) - if pixelPos is None: + xmin, xmax = self._plot.getXAxis().getLimits() + ymin, ymax = self._plot.getYAxis(axis=yAxis).getLimits() + if not xmin < xCoord < xmax or not ymin < yCoord < ymax: # Do not render markers outside visible plot area continue + pixelPos = self._plotFrame.dataToPixel( + xCoord, yCoord, axis=yAxis) if isYInverted: - valign = BOTTOM + valign = glutils.BOTTOM vPixelOffset = -pixelOffset else: - valign = TOP + valign = glutils.TOP vPixelOffset = pixelOffset if item['text'] is not None: x = pixelPos[0] + pixelOffset y = pixelPos[1] + vPixelOffset - label = Text2D(item['text'], x, y, - color=item['color'], - bgColor=(1., 1., 1., 0.5), - align=LEFT, valign=valign) + label = glutils.Text2D( + item['text'], x, y, + color=item['color'], + bgColor=(1., 1., 1., 0.5), + align=glutils.LEFT, + valign=valign, + devicePixelRatio=self.getDevicePixelRatio()) labels.append(label) # For now simple implementation: using a curve for each marker # Should pack all markers to a single set of points - markerCurve = GLPlotCurve2D( + markerCurve = glutils.GLPlotCurve2D( numpy.array((pixelPos[0],), dtype=numpy.float64), numpy.array((pixelPos[1],), dtype=numpy.float64), marker=item['symbol'], markerColor=item['color'], markerSize=11) - markerCurve.render(self.matScreenProj, False, False) + + context = glutils.RenderContext( + matrix=self.matScreenProj, + isXLog=False, + isYLog=False, + dpi=self.getDotsPerInch()) + markerCurve.render(context) markerCurve.discard() else: @@ -605,7 +637,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): def _renderOverlayGL(self): """Render overlay layer: overlay items and crosshair.""" - plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] + plotWidth, plotHeight = self._plotFrame.plotSize # Scissor to plot area gl.glScissor(self._plotFrame.margins.left, @@ -658,7 +690,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): It renders the background, grid and items except overlays """ - plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] + plotWidth, plotHeight = self._plotFrame.plotSize gl.glScissor(self._plotFrame.margins.left, self._plotFrame.margins.bottom, @@ -687,9 +719,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): int(self.getDevicePixelRatio() * width), int(self.getDevicePixelRatio() * height)) - self.matScreenProj = mat4Ortho(0, self._plotFrame.size[0], - self._plotFrame.size[1], 0, - 1, -1) + self.matScreenProj = glutils.mat4Ortho( + 0, self._plotFrame.size[0], + self._plotFrame.size[1], 0, + 1, -1) # Store current ranges previousXRange = self.getGraphXLimits() @@ -824,21 +857,20 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): fillColor = None if fill is True: fillColor = color - curve = GLPlotCurve2D(x, y, colorArray, - xError=xerror, - yError=yerror, - lineStyle=linestyle, - lineColor=color, - lineWidth=linewidth, - marker=symbol, - markerColor=color, - markerSize=symbolsize, - fillColor=fillColor, - baseline=baseline, - isYLog=isYLog) - curve.info = { - 'yAxis': 'left' if yaxis is None else yaxis, - } + curve = glutils.GLPlotCurve2D( + x, y, colorArray, + xError=xerror, + yError=yerror, + lineStyle=linestyle, + lineColor=color, + lineWidth=linewidth, + marker=symbol, + markerColor=color, + markerSize=symbolsize, + fillColor=fillColor, + baseline=baseline, + isYLog=isYLog) + curve.yaxis = 'left' if yaxis is None else yaxis if yaxis == "right": self._plotFrame.isY2Axis = True @@ -853,7 +885,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): if data.ndim == 2: # Ensure array is contiguous and eventually convert its type - if data.dtype in (numpy.float32, numpy.uint8, numpy.uint16): + dtypes = [dtype for dtype in ( + numpy.float32, numpy.float16, numpy.uint8, numpy.uint16) + if glu.isSupportedGLType(dtype)] + if data.dtype in dtypes: data = numpy.array(data, copy=False, order='C') else: _logger.info( @@ -861,24 +896,27 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): data = numpy.array(data, dtype=numpy.float32, order='C') normalization = colormap.getNormalization() - if normalization in GLPlotColormap.SUPPORTED_NORMALIZATIONS: + if normalization in glutils.GLPlotColormap.SUPPORTED_NORMALIZATIONS: # Fast path applying colormap on the GPU cmapRange = colormap.getColormapRange(data=data) colormapLut = colormap.getNColors(nbColors=256) gamma = colormap.getGammaNormalizationParameter() - - image = GLPlotColormap(data, - origin, - scale, - colormapLut, - normalization, - gamma, - cmapRange, - alpha) + nanColor = colors.rgba(colormap.getNaNColor()) + + image = glutils.GLPlotColormap( + data, + origin, + scale, + colormapLut, + normalization, + gamma, + cmapRange, + alpha, + nanColor) else: # Fallback applying colormap on CPU rgba = colormap.applyToData(data) - image = GLPlotRGBAImage(rgba, origin, scale, alpha) + image = glutils.GLPlotRGBAImage(rgba, origin, scale, alpha) elif len(data.shape) == 3: # For RGB, RGBA data @@ -893,7 +931,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): else: raise ValueError('Unsupported data type') - image = GLPlotRGBAImage(data, origin, scale, alpha) + image = glutils.GLPlotRGBAImage(data, origin, scale, alpha) else: raise RuntimeError("Unsupported data shape {0}".format(data.shape)) @@ -916,7 +954,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): if self._plotFrame.yAxis.isLog: y = numpy.log10(y) - triangles = GLPlotTriangles(x, y, color, triangles, alpha) + triangles = glutils.GLPlotTriangles(x, y, color, triangles, alpha) return triangles @@ -944,11 +982,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): # Remove methods def remove(self, item): - if isinstance(item, (GLPlotCurve2D, - GLPlotColormap, - GLPlotRGBAImage, - GLPlotTriangles)): - if isinstance(item, GLPlotCurve2D): + if isinstance(item, glutils.GLPlotItem): + if item.yaxis == 'right': # Check if some curves remains on the right Y axis y2AxisItems = (item for item in self._plot.getItems() if isinstance(item, items.YAxisMixIn) and @@ -997,13 +1032,18 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): _PICK_OFFSET = 3 # Offset in pixel used for picking def _mouseInPlotArea(self, x, y): - xPlot = numpy.clip( - x, self._plotFrame.margins.left, - self._plotFrame.size[0] - self._plotFrame.margins.right - 1) - yPlot = numpy.clip( - y, self._plotFrame.margins.top, - self._plotFrame.size[1] - self._plotFrame.margins.bottom - 1) - return xPlot, yPlot + """Returns closest visible position in the plot. + + This is performed in Qt widget pixel, not device pixel. + + :param float x: X coordinate in Qt widget pixel + :param float y: Y coordinate in Qt widget pixel + :return: (x, y) closest point in the plot. + :rtype: List[float] + """ + left, top, width, height = self.getPlotBoundsInPixels() + return (numpy.clip(x, left, left + width - 1), # TODO -1? + numpy.clip(y, top, top + height - 1)) def __pickCurves(self, item, x, y): """Perform picking on a curve item. @@ -1016,22 +1056,26 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): """ offset = self._PICK_OFFSET if item.marker is not None: - offset = max(item.markerSize / 2., offset) + # Convert markerSize from points to qt pixels + qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio() + size = item.markerSize / 72. * qtDpi + offset = max(size / 2., offset) if item.lineStyle is not None: - offset = max(item.lineWidth / 2., offset) - - yAxis = item.info['yAxis'] + # Convert line width from points to qt pixels + qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio() + lineWidth = item.lineWidth / 72. * qtDpi + offset = max(lineWidth / 2., offset) inAreaPos = self._mouseInPlotArea(x - offset, y - offset) dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1], - axis=yAxis, check=True) + axis=item.yaxis, check=True) if dataPos is None: return None xPick0, yPick0 = dataPos inAreaPos = self._mouseInPlotArea(x + offset, y + offset) dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1], - axis=yAxis, check=True) + axis=item.yaxis, check=True) if dataPos is None: return None xPick1, yPick1 = dataPos @@ -1051,8 +1095,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): xPickMin = numpy.log10(xPickMin) xPickMax = numpy.log10(xPickMax) - if (yAxis == 'left' and self._plotFrame.yAxis.isLog) or ( - yAxis == 'right' and self._plotFrame.y2Axis.isLog): + if (item.yaxis == 'left' and self._plotFrame.yAxis.isLog) or ( + item.yaxis == 'right' and self._plotFrame.y2Axis.isLog): yPickMin = numpy.log10(yPickMin) yPickMax = numpy.log10(yPickMax) @@ -1060,6 +1104,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): xPickMax, yPickMax) def pickItem(self, x, y, item): + # Picking is performed in Qt widget pixels not device pixels dataPos = self._plot.pixelToData(x, y, axis='left', check=True) if dataPos is None: return None # Outside plot area @@ -1100,17 +1145,11 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): return (0,) if isPicked else None # Pick image, curve, triangles - elif isinstance(item, (GLPlotCurve2D, - GLPlotColormap, - GLPlotRGBAImage, - GLPlotTriangles)): - if isinstance(item, (GLPlotColormap, GLPlotRGBAImage, GLPlotTriangles)): - return item.pick(*dataPos) # Might be None - - elif isinstance(item, GLPlotCurve2D): + elif isinstance(item, glutils.GLPlotItem): + if isinstance(item, glutils.GLPlotCurve2D): return self.__pickCurves(item, x, y) else: - return None + return item.pick(*dataPos) # Might be None # Update curve @@ -1184,8 +1223,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): if axis == 'left': self._plotFrame.yAxis.title = label else: # right axis - if label: - _logger.warning('Right axis label not implemented') + self._plotFrame.y2Axis.title = label # Graph limits @@ -1209,7 +1247,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): :param str keepDim: The dimension to maintain: 'x', 'y' or None. If None (the default), the dimension with the largest range. """ - plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:] + plotWidth, plotHeight = self._plotFrame.plotSize if plotWidth <= 2 or plotHeight <= 2: return @@ -1352,17 +1390,25 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget): # Data <-> Pixel coordinates conversion def dataToPixel(self, x, y, axis): - return self._plotFrame.dataToPixel(x, y, axis) + result = self._plotFrame.dataToPixel(x, y, axis) + if result is None: + return None + else: + devicePixelRatio = self.getDevicePixelRatio() + return tuple(value/devicePixelRatio for value in result) def pixelToData(self, x, y, axis): - return self._plotFrame.pixelToData(x, y, axis) + devicePixelRatio = self.getDevicePixelRatio() + return self._plotFrame.pixelToData( + x * devicePixelRatio, y * devicePixelRatio, axis) def getPlotBoundsInPixels(self): - return self._plotFrame.plotOrigin + self._plotFrame.plotSize + devicePixelRatio = self.getDevicePixelRatio() + return tuple(int(value / devicePixelRatio) + for value in self._plotFrame.plotOrigin + self._plotFrame.plotSize) - def setAxesDisplayed(self, displayed): - BackendBase.BackendBase.setAxesDisplayed(self, displayed) - self._plotFrame.displayed = displayed + def setAxesMargins(self, left: float, top: float, right: float, bottom: float): + self._plotFrame.marginRatios = left, top, right, bottom def setForegroundColors(self, foregroundColor, gridColor): self._plotFrame.foregroundColor = foregroundColor diff --git a/silx/gui/plot/backends/glutils/GLPlotCurve.py b/silx/gui/plot/backends/glutils/GLPlotCurve.py index 9ab85fd..c4e2c1e 100644 --- a/silx/gui/plot/backends/glutils/GLPlotCurve.py +++ b/silx/gui/plot/backends/glutils/GLPlotCurve.py @@ -43,6 +43,7 @@ from silx.math.combo import min_max from ...._glutils import gl from ...._glutils import Program, vertexBuffer, VertexBufferAttrib from .GLSupport import buildFillMaskIndices, mat4Identity, mat4Translate +from .GLPlotImage import GLPlotItem _logger = logging.getLogger(__name__) @@ -172,10 +173,10 @@ class _Fill2D(object): self._xFillVboData, self._yFillVboData = vertexBuffer(points.T) - def render(self, matrix): + def render(self, context): """Perform rendering - :param numpy.ndarray matrix: 4x4 transform matrix to use + :param RenderContext context: """ self.prepare() @@ -186,7 +187,7 @@ class _Fill2D(object): gl.glUniformMatrix4fv( self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE, - numpy.dot(matrix, + numpy.dot(context.matrix, mat4Translate(*self.offset)).astype(numpy.float32)) gl.glUniform4f(self._PROGRAM.uniforms['color'], *self.color) @@ -404,11 +405,13 @@ class GLLines2D(object): """OpenGL context initialization""" gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST) - def render(self, matrix): + def render(self, context): """Perform rendering - :param numpy.ndarray matrix: 4x4 transform matrix to use + :param RenderContext context: """ + width = self.width / 72. * context.dpi + style = self.style if style is None: return @@ -425,7 +428,7 @@ class GLLines2D(object): gl.glUniform2f(program.uniforms['halfViewportSize'], 0.5 * viewWidth, 0.5 * viewHeight) - dashPeriod = self.dashPeriod * self.width + dashPeriod = self.dashPeriod * width if self.style == DOTTED: dash = (0.2 * dashPeriod, 0.5 * dashPeriod, @@ -463,10 +466,10 @@ class GLLines2D(object): 0, self.distVboData) - if self.width != 1: + if width != 1: gl.glEnable(gl.GL_LINE_SMOOTH) - matrix = numpy.dot(matrix, + matrix = numpy.dot(context.matrix, mat4Translate(*self.offset)).astype(numpy.float32) gl.glUniformMatrix4fv(program.uniforms['matrix'], 1, gl.GL_TRUE, matrix) @@ -503,7 +506,7 @@ class GLLines2D(object): 0, self.yVboData) - gl.glLineWidth(self.width) + gl.glLineWidth(width) gl.glDrawArrays(self._drawMode, 0, self.xVboData.size) gl.glDisable(gl.GL_LINE_SMOOTH) @@ -516,10 +519,26 @@ def distancesFromArrays(xData, yData): :param numpy.ndarray yData: Y coordinate of points :rtype: numpy.ndarray """ - deltas = numpy.dstack(( - numpy.ediff1d(xData, to_begin=numpy.float32(0.)), - numpy.ediff1d(yData, to_begin=numpy.float32(0.))))[0] - return numpy.cumsum(numpy.sqrt(numpy.sum(deltas ** 2, axis=1))) + # Split array into sub-shapes at not finite points + splits = numpy.nonzero(numpy.logical_not(numpy.logical_and( + numpy.isfinite(xData), numpy.isfinite(yData))))[0] + splits = numpy.concatenate(([-1], splits, [len(xData) - 1])) + + # Compute distance independently for each sub-shapes, + # putting not finite points as last points of sub-shapes + distances = [] + for begin, end in zip(splits[:-1] + 1, splits[1:] + 1): + if begin == end: # Empty shape + continue + elif end - begin == 1: # Single element + distances.append([0]) + else: + deltas = numpy.dstack(( + numpy.ediff1d(xData[begin:end], to_begin=numpy.float32(0.)), + numpy.ediff1d(yData[begin:end], to_begin=numpy.float32(0.))))[0] + distances.append( + numpy.cumsum(numpy.sqrt(numpy.sum(deltas ** 2, axis=1)))) + return numpy.concatenate(distances) # points ###################################################################### @@ -833,10 +852,10 @@ class _Points2D(object): if majorVersion >= 3: # OpenGL 3 gl.glEnable(gl.GL_PROGRAM_POINT_SIZE) - def render(self, matrix): + def render(self, context): """Perform rendering - :param numpy.ndarray matrix: 4x4 transform matrix to use + :param RenderContext context: """ if self.marker is None: return @@ -844,7 +863,7 @@ class _Points2D(object): program = self._getProgram(self.marker) program.use() - matrix = numpy.dot(matrix, + matrix = numpy.dot(context.matrix, mat4Translate(*self.offset)).astype(numpy.float32) gl.glUniformMatrix4fv(program.uniforms['matrix'], 1, gl.GL_TRUE, matrix) @@ -854,6 +873,13 @@ class _Points2D(object): size = math.ceil(0.5 * self.size) + 1 # Mimic Matplotlib point else: size = self.size + size = size / 72. * context.dpi + + if self.marker in (PLUS, H_LINE, V_LINE, + TICK_LEFT, TICK_RIGHT, TICK_UP, TICK_DOWN): + # Convert to nearest odd number + size = size // 2 * 2 + 1. + gl.glUniform1f(program.uniforms['size'], size) # gl.glPointSize(self.size) @@ -1021,17 +1047,17 @@ class _ErrorBars(object): self._yErrPoints.yVboData.offset += (yAttrib.itemsize * yAttrib.size // 2) - def render(self, matrix): + def render(self, context): """Perform rendering - :param numpy.ndarray matrix: 4x4 transform matrix to use + :param RenderContext context: """ self.prepare() if self._attribs is not None: - self._lines.render(matrix) - self._xErrPoints.render(matrix) - self._yErrPoints.render(matrix) + self._lines.render(context) + self._xErrPoints.render(context) + self._yErrPoints.render(context) def discard(self): """Release VBOs""" @@ -1067,7 +1093,7 @@ def _proxyProperty(*componentsAttributes): return property(getter, setter) -class GLPlotCurve2D(object): +class GLPlotCurve2D(GLPlotItem): def __init__(self, xData, yData, colorData=None, xError=None, yError=None, lineStyle=SOLID, @@ -1080,7 +1106,7 @@ class GLPlotCurve2D(object): fillColor=None, baseline=None, isYLog=False): - + super().__init__() self.colorData = colorData # Compute x bounds @@ -1220,19 +1246,17 @@ class GLPlotCurve2D(object): self.colorVboData = cAttrib self.useColorVboData = cAttrib is not None - def render(self, matrix, isXLog, isYLog): + def render(self, context): """Perform rendering - :param numpy.ndarray matrix: 4x4 transform matrix to use - :param bool isXLog: - :param bool isYLog: + :param RenderContext context: Rendering information """ self.prepare() if self.fill is not None: - self.fill.render(matrix) - self._errorBars.render(matrix) - self.lines.render(matrix) - self.points.render(matrix) + self.fill.render(context) + self._errorBars.render(context) + self.lines.render(context) + self.points.render(context) def discard(self): """Release VBOs""" diff --git a/silx/gui/plot/backends/glutils/GLPlotFrame.py b/silx/gui/plot/backends/glutils/GLPlotFrame.py index 43f6e10..c5ee75b 100644 --- a/silx/gui/plot/backends/glutils/GLPlotFrame.py +++ b/silx/gui/plot/backends/glutils/GLPlotFrame.py @@ -61,7 +61,7 @@ class PlotAxis(object): This class is intended to be used with :class:`GLPlotFrame`. """ - def __init__(self, plot, + def __init__(self, plotFrame, tickLength=(0., 0.), foregroundColor=(0., 0., 0., 1.0), labelAlign=CENTER, labelVAlign=CENTER, @@ -69,7 +69,7 @@ class PlotAxis(object): titleRotate=0, titleOffset=(0., 0.)): self._ticks = None - self._plot = weakref.ref(plot) + self._plotFrameRef = weakref.ref(plotFrame) self._isDateTime = False self._timeZone = None @@ -156,6 +156,12 @@ class PlotAxis(object): self._displayCoords = displayCoords self._dirtyTicks() + @property + def devicePixelRatio(self): + """Returns the ratio between qt pixels and device pixels.""" + plotFrame = self._plotFrameRef() + return plotFrame.devicePixelRatio if plotFrame is not None else 1. + @property def title(self): """The text label associated with this axis as a str in latin-1.""" @@ -165,10 +171,18 @@ class PlotAxis(object): def title(self, title): if title != self._title: self._title = title + self._dirtyPlotFrame() - plot = self._plot() - if plot is not None: - plot._dirty() + @property + def titleOffset(self): + """Title offset in pixels (x: int, y: int)""" + return self._titleOffset + + @titleOffset.setter + def titleOffset(self, offset): + if offset != self._titleOffset: + self._titleOffset = offset + self._dirtyTicks() @property def foregroundColor(self): @@ -201,6 +215,8 @@ class PlotAxis(object): tickLabelsSize = [0., 0.] xTickLength, yTickLength = self._tickLength + xTickLength *= self.devicePixelRatio + yTickLength *= self.devicePixelRatio for (xPixel, yPixel), dataPos, text in self.ticks: if text is None: tickScale = 0.5 @@ -212,7 +228,8 @@ class PlotAxis(object): x=xPixel - xTickLength, y=yPixel - yTickLength, align=self._labelAlign, - valign=self._labelVAlign) + valign=self._labelVAlign, + devicePixelRatio=self.devicePixelRatio) width, height = label.size if width > tickLabelsSize[0]: @@ -230,7 +247,7 @@ class PlotAxis(object): xAxisCenter = 0.5 * (x0 + x1) yAxisCenter = 0.5 * (y0 + y1) - xOffset, yOffset = self._titleOffset + xOffset, yOffset = self.titleOffset # Adaptative title positioning: # tickNorm = math.sqrt(xTickLength ** 2 + yTickLength ** 2) @@ -245,17 +262,22 @@ class PlotAxis(object): y=yAxisCenter + yOffset, align=self._titleAlign, valign=self._titleVAlign, - rotate=self._titleRotate) + rotate=self._titleRotate, + devicePixelRatio=self.devicePixelRatio) labels.append(axisTitle) return vertices, labels + def _dirtyPlotFrame(self): + """Dirty parent GLPlotFrame""" + plotFrame = self._plotFrameRef() + if plotFrame is not None: + plotFrame._dirty() + def _dirtyTicks(self): """Mark ticks as dirty and notify listener (i.e., background).""" self._ticks = None - plot = self._plot() - if plot is not None: - plot._dirty() + self._dirtyPlotFrame() @staticmethod def _frange(start, stop, step): @@ -314,7 +336,7 @@ class PlotAxis(object): xScale = (x1 - x0) / (dataMax - dataMin) yScale = (y1 - y0) / (dataMax - dataMin) - nbPixels = math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2)) + nbPixels = math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2)) / self.devicePixelRatio # Density of 1.3 label per 92 pixels # i.e., 1.3 label per inch on a 92 dpi screen @@ -391,11 +413,11 @@ class GLPlotFrame(object): # Margins used when plot frame is not displayed _NoDisplayMargins = _Margins(0, 0, 0, 0) - def __init__(self, margins, foregroundColor, gridColor): + def __init__(self, marginRatios, foregroundColor, gridColor): """ - :param margins: The margins around plot area for axis and labels. - :type margins: dict with 'left', 'right', 'top', 'bottom' keys and - values as ints. + :param List[float] marginRatios: + The ratios of margins around plot area for axis and labels. + (left, top, right, bottom) as float in [0., 1.] :param foregroundColor: color used for the frame and labels. :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0 :param gridColor: color used for grid lines. @@ -403,7 +425,9 @@ class GLPlotFrame(object): """ self._renderResources = None - self._margins = self._Margins(**margins) + self.__marginRatios = marginRatios + self.__marginsCache = None + self._foregroundColor = foregroundColor self._gridColor = gridColor @@ -412,7 +436,8 @@ class GLPlotFrame(object): self._grid = False self._size = 0., 0. self._title = '' - self._displayed = True + + self._devicePixelRatio = 1. @property def isDirty(self): @@ -453,26 +478,49 @@ class GLPlotFrame(object): if self._gridColor != color: self._gridColor = color self._dirty() - + @property - def displayed(self): - """Whether axes and their labels are displayed or not (bool)""" - return self._displayed - - @displayed.setter - def displayed(self, displayed): - displayed = bool(displayed) - if displayed != self._displayed: - self._displayed = displayed + def marginRatios(self): + """Plot margin ratios: (left, top, right, bottom) as 4 float in [0, 1]. + """ + return self.__marginRatios + + @marginRatios.setter + def marginRatios(self, ratios): + ratios = tuple(float(v) for v in ratios) + assert len(ratios) == 4 + for value in ratios: + assert 0. <= value <= 1. + assert ratios[0] + ratios[2] < 1. + assert ratios[1] + ratios[3] < 1. + + if self.__marginRatios != ratios: + self.__marginRatios = ratios + self.__marginsCache = None # Clear cached margins self._dirty() @property def margins(self): """Margins in pixels around the plot.""" - if not self.displayed: - return self._NoDisplayMargins - else: - return self._margins + if self.__marginsCache is None: + width, height = self.size + left, top, right, bottom = self.marginRatios + self.__marginsCache = self._Margins( + left=int(left*width), + right=int(right*width), + top=int(top*height), + bottom=int(bottom*height)) + return self.__marginsCache + + @property + def devicePixelRatio(self): + return self._devicePixelRatio + + @devicePixelRatio.setter + def devicePixelRatio(self, ratio): + if ratio != self._devicePixelRatio: + self._devicePixelRatio = ratio + self._dirty() @property def grid(self): @@ -493,7 +541,7 @@ class GLPlotFrame(object): @property def size(self): - """Size in pixels of the plot area including margins.""" + """Size in device pixels of the plot area including margins.""" return self._size @size.setter @@ -502,6 +550,7 @@ class GLPlotFrame(object): size = tuple(size) if size != self._size: self._size = size + self.__marginsCache = None # Clear cached margins self._dirty() @property @@ -580,7 +629,8 @@ class GLPlotFrame(object): x=xTitle, y=yTitle, align=CENTER, - valign=BOTTOM)) + valign=BOTTOM, + devicePixelRatio=self.devicePixelRatio)) # grid gridVertices = numpy.array(self._buildGridVertices(), @@ -592,7 +642,7 @@ class GLPlotFrame(object): _SHADERS['vertex'], _SHADERS['fragment'], attrib0='position') def render(self): - if not self.displayed: + if self.margins == self._NoDisplayMargins: return if self._renderResources is None: @@ -661,25 +711,24 @@ class GLPlotFrame(object): # GLPlotFrame2D ############################################################### class GLPlotFrame2D(GLPlotFrame): - def __init__(self, margins, foregroundColor, gridColor): + def __init__(self, marginRatios, foregroundColor, gridColor): """ - :param margins: The margins around plot area for axis and labels. - :type margins: dict with 'left', 'right', 'top', 'bottom' keys and - values as ints. + :param List[float] marginRatios: + The ratios of margins around plot area for axis and labels. + (left, top, right, bottom) as float in [0., 1.] :param foregroundColor: color used for the frame and labels. :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0 :param gridColor: color used for grid lines. :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0 """ - super(GLPlotFrame2D, self).__init__(margins, foregroundColor, gridColor) + super(GLPlotFrame2D, self).__init__(marginRatios, foregroundColor, gridColor) self.axes.append(PlotAxis(self, tickLength=(0., -5.), foregroundColor=self._foregroundColor, labelAlign=CENTER, labelVAlign=TOP, titleAlign=CENTER, titleVAlign=TOP, - titleRotate=0, - titleOffset=(0, self.margins.bottom // 2))) + titleRotate=0)) self._x2AxisCoords = () @@ -688,18 +737,14 @@ class GLPlotFrame2D(GLPlotFrame): foregroundColor=self._foregroundColor, labelAlign=RIGHT, labelVAlign=CENTER, titleAlign=CENTER, titleVAlign=BOTTOM, - titleRotate=ROTATE_270, - titleOffset=(-3 * self.margins.left // 4, - 0))) + titleRotate=ROTATE_270)) self._y2Axis = PlotAxis(self, tickLength=(-5., 0.), foregroundColor=self._foregroundColor, labelAlign=LEFT, labelVAlign=CENTER, titleAlign=CENTER, titleVAlign=TOP, - titleRotate=ROTATE_270, - titleOffset=(3 * self.margins.right // 4, - 0)) + titleRotate=ROTATE_270) self._isYAxisInverted = False @@ -794,6 +839,24 @@ class GLPlotFrame2D(GLPlotFrame): self._baseVectors = vectors self._dirty() + def _updateTitleOffset(self): + """Update axes title offset according to margins""" + margins = self.margins + self.xAxis.titleOffset = 0, margins.bottom // 2 + self.yAxis.titleOffset = -3 * margins.left // 4, 0 + self.y2Axis.titleOffset = 3 * margins.right // 4, 0 + + # Override size and marginRatios setters to update titleOffsets + @GLPlotFrame.size.setter + def size(self, size): + GLPlotFrame.size.fset(self, size) + self._updateTitleOffset() + + @GLPlotFrame.marginRatios.setter + def marginRatios(self, ratios): + GLPlotFrame.marginRatios.fset(self, ratios) + self._updateTitleOffset() + @property def dataRanges(self): """Ranges of data visible in the plot on x, y and y2 axes. diff --git a/silx/gui/plot/backends/glutils/GLPlotImage.py b/silx/gui/plot/backends/glutils/GLPlotImage.py index e985a3d..f60a159 100644 --- a/silx/gui/plot/backends/glutils/GLPlotImage.py +++ b/silx/gui/plot/backends/glutils/GLPlotImage.py @@ -40,10 +40,12 @@ from ...._glutils import gl, Program, Texture from ..._utils import FLOAT32_MINPOS from .GLSupport import mat4Translate, mat4Scale from .GLTexture import Image +from .GLPlotItem import GLPlotItem -class _GLPlotData2D(object): +class _GLPlotData2D(GLPlotItem): def __init__(self, data, origin, scale): + super().__init__() self.data = data assert len(origin) == 2 self.origin = tuple(origin) @@ -80,15 +82,6 @@ class _GLPlotData2D(object): oy, sy = self.origin[1], self.scale[1] return oy + sy * self.data.shape[0] if sy >= 0. else oy - def discard(self): - pass - - def prepare(self): - pass - - def render(self, matrix, isXLog, isYLog): - pass - class GLPlotColormap(_GLPlotData2D): @@ -160,6 +153,11 @@ class GLPlotColormap(_GLPlotData2D): 'fragment': """ #version 120 + /* isnan declaration for compatibility with GLSL 1.20 */ + bool isnan(float value) { + return (value != value); + } + uniform sampler2D data; uniform sampler2D cmap_texture; uniform int cmap_normalization; @@ -167,6 +165,7 @@ class GLPlotColormap(_GLPlotData2D): uniform float cmap_min; uniform float cmap_oneOverRange; uniform float alpha; + uniform vec4 nancolor; varying vec2 coords; @@ -175,7 +174,8 @@ class GLPlotColormap(_GLPlotData2D): const float oneOverLog10 = 0.43429448190325176; void main(void) { - float value = texture2D(data, textureCoords()).r; + float data = texture2D(data, textureCoords()).r; + float value = data; if (cmap_normalization == 1) { /*Logarithm mapping*/ if (value > 0.) { value = clamp(cmap_oneOverRange * @@ -202,7 +202,11 @@ class GLPlotColormap(_GLPlotData2D): value = clamp(cmap_oneOverRange * (value - cmap_min), 0., 1.); } - gl_FragColor = texture2D(cmap_texture, vec2(value, 0.5)); + if (isnan(data)) { + gl_FragColor = nancolor; + } else { + gl_FragColor = texture2D(cmap_texture, vec2(value, 0.5)); + } gl_FragColor.a *= alpha; } """ @@ -213,6 +217,7 @@ class GLPlotColormap(_GLPlotData2D): _INTERNAL_FORMATS = { numpy.dtype(numpy.float32): gl.GL_R32F, + numpy.dtype(numpy.float16): gl.GL_R16F, # Use normalized integer for unsigned int formats numpy.dtype(numpy.uint16): gl.GL_R16, numpy.dtype(numpy.uint8): gl.GL_R8, @@ -232,7 +237,7 @@ class GLPlotColormap(_GLPlotData2D): def __init__(self, data, origin, scale, colormap, normalization='linear', gamma=0., cmapRange=None, - alpha=1.0): + alpha=1.0, nancolor=(1., 1., 1., 0.)): """Create a 2D colormap :param data: The 2D scalar data array to display @@ -252,6 +257,8 @@ class GLPlotColormap(_GLPlotData2D): TODO: check consistency with matplotlib :type cmapRange: (float, float) or None :param float alpha: Opacity from 0 (transparent) to 1 (opaque) + :param nancolor: RGBA color for Not-A-Number values + :type nancolor: 4-tuple of float in [0., 1.] """ assert data.dtype in self._INTERNAL_FORMATS assert normalization in self.SUPPORTED_NORMALIZATIONS @@ -263,6 +270,7 @@ class GLPlotColormap(_GLPlotData2D): self._cmapRange = (1., 10.) # Colormap range self.cmapRange = cmapRange # Update _cmapRange self._alpha = numpy.clip(alpha, 0., 1.) + self._nancolor = numpy.clip(nancolor, 0., 1.) self._cmap_texture = None self._texture = None @@ -283,7 +291,7 @@ class GLPlotColormap(_GLPlotData2D): if self.normalization == 'log': assert self._cmapRange[0] > 0. and self._cmapRange[1] > 0. elif self.normalization == 'sqrt': - assert self._cmapRange[0] >= 0. and self._cmapRange[1] > 0. + assert self._cmapRange[0] >= 0. and self._cmapRange[1] >= 0. return self._cmapRange @cmapRange.setter @@ -324,6 +332,7 @@ class GLPlotColormap(_GLPlotData2D): magFilter=gl.GL_NEAREST, wrap=(gl.GL_CLAMP_TO_EDGE, gl.GL_CLAMP_TO_EDGE)) + self._cmap_texture.prepare() if self._texture is None: internalFormat = self._INTERNAL_FORMATS[self.data.dtype] @@ -376,9 +385,15 @@ class GLPlotColormap(_GLPlotData2D): oneOverRange = 0. # Fall-back gl.glUniform1f(prog.uniforms['cmap_oneOverRange'], oneOverRange) + gl.glUniform4f(prog.uniforms['nancolor'], *self._nancolor) + self._cmap_texture.bind() - def _renderLinear(self, matrix): + def _renderLinear(self, context): + """Perform rendering when both axes have linear scales + + :param RenderContext context: Rendering information + """ self.prepare() prog = self._linearProgram @@ -386,7 +401,7 @@ class GLPlotColormap(_GLPlotData2D): gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT) - mat = numpy.dot(numpy.dot(matrix, + mat = numpy.dot(numpy.dot(context.matrix, mat4Translate(*self.origin)), mat4Scale(*self.scale)) gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, @@ -400,10 +415,14 @@ class GLPlotColormap(_GLPlotData2D): prog.attributes['texCoords'], self._DATA_TEX_UNIT) - def _renderLog10(self, matrix, isXLog, isYLog): + def _renderLog10(self, context): + """Perform rendering when one axis has log scale + + :param RenderContext context: Rendering information + """ xMin, yMin = self.xMin, self.yMin - if ((isXLog and xMin < FLOAT32_MINPOS) or - (isYLog and yMin < FLOAT32_MINPOS)): + if ((context.isXLog and xMin < FLOAT32_MINPOS) or + (context.isYLog and yMin < FLOAT32_MINPOS)): # Do not render images that are partly or totally <= 0 return @@ -417,12 +436,12 @@ class GLPlotColormap(_GLPlotData2D): gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT) gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, - matrix.astype(numpy.float32)) + context.matrix.astype(numpy.float32)) mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale)) gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE, mat.astype(numpy.float32)) - gl.glUniform2i(prog.uniforms['isLog'], isXLog, isYLog) + gl.glUniform2i(prog.uniforms['isLog'], context.isXLog, context.isYLog) ex = ox + self.scale[0] * self.data.shape[1] ey = oy + self.scale[1] * self.data.shape[0] @@ -461,11 +480,15 @@ class GLPlotColormap(_GLPlotData2D): gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices)) - def render(self, matrix, isXLog, isYLog): - if any((isXLog, isYLog)): - self._renderLog10(matrix, isXLog, isYLog) + def render(self, context): + """Perform rendering + + :param RenderContext context: Rendering information + """ + if any((context.isXLog, context.isYLog)): + self._renderLog10(context) else: - self._renderLinear(matrix) + self._renderLinear(context) # Unbind colormap texture gl.glActiveTexture(gl.GL_TEXTURE0 + self._cmap_texture.texUnit) @@ -635,7 +658,11 @@ class GLPlotRGBAImage(_GLPlotData2D): format_ = gl.GL_RGBA if self.data.shape[2] == 4 else gl.GL_RGB self._texture.updateAll(format_=format_, data=self.data) - def _renderLinear(self, matrix): + def _renderLinear(self, context): + """Perform rendering with both axes having linear scales + + :param RenderContext context: Rendering information + """ self.prepare() prog = self._linearProgram @@ -643,7 +670,7 @@ class GLPlotRGBAImage(_GLPlotData2D): gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT) - mat = numpy.dot(numpy.dot(matrix, mat4Translate(*self.origin)), + mat = numpy.dot(numpy.dot(context.matrix, mat4Translate(*self.origin)), mat4Scale(*self.scale)) gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, mat.astype(numpy.float32)) @@ -654,7 +681,11 @@ class GLPlotRGBAImage(_GLPlotData2D): prog.attributes['texCoords'], self._DATA_TEX_UNIT) - def _renderLog(self, matrix, isXLog, isYLog): + def _renderLog(self, context): + """Perform rendering with axes having log scale + + :param RenderContext context: Rendering information + """ self.prepare() prog = self._logProgram @@ -665,12 +696,12 @@ class GLPlotRGBAImage(_GLPlotData2D): gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT) gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE, - matrix.astype(numpy.float32)) + context.matrix.astype(numpy.float32)) mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale)) gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE, mat.astype(numpy.float32)) - gl.glUniform2i(prog.uniforms['isLog'], isXLog, isYLog) + gl.glUniform2i(prog.uniforms['isLog'], context.isXLog, context.isYLog) gl.glUniform1f(prog.uniforms['alpha'], self.alpha) @@ -707,8 +738,12 @@ class GLPlotRGBAImage(_GLPlotData2D): gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices)) - def render(self, matrix, isXLog, isYLog): - if any((isXLog, isYLog)): - self._renderLog(matrix, isXLog, isYLog) + def render(self, context): + """Perform rendering + + :param RenderContext context: Rendering information + """ + if any((context.isXLog, context.isYLog)): + self._renderLog(context) else: - self._renderLinear(matrix) + self._renderLinear(context) diff --git a/silx/gui/plot/backends/glutils/GLPlotItem.py b/silx/gui/plot/backends/glutils/GLPlotItem.py new file mode 100644 index 0000000..899f38e --- /dev/null +++ b/silx/gui/plot/backends/glutils/GLPlotItem.py @@ -0,0 +1,94 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ############################################################################*/ +""" +This module provides a base class for PlotWidget OpenGL backend primitives +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "02/07/2020" + + +class RenderContext: + """Context with which to perform OpenGL rendering. + + :param numpy.ndarray matrix: 4x4 transform matrix to use for rendering + :param bool isXLog: Whether X axis is log scale or not + :param bool isYLog: Whether Y axis is log scale or not + :param float dpi: Number of device pixels per inch + """ + + def __init__(self, matrix=None, isXLog=False, isYLog=False, dpi=96.): + self.matrix = matrix + """Current transformation matrix""" + + self.__isXLog = isXLog + self.__isYLog = isYLog + self.__dpi = dpi + + @property + def isXLog(self): + """True if X axis is using log scale""" + return self.__isXLog + + @property + def isYLog(self): + """True if Y axis is using log scale""" + return self.__isYLog + + @property + def dpi(self): + """Number of device pixels per inch""" + return self.__dpi + + +class GLPlotItem: + """Base class for primitives used in the PlotWidget OpenGL backend""" + + def __init__(self): + self.yaxis = 'left' + "YAxis this item is attached to (either 'left' or 'right')" + + def pick(self, x, y): + """Perform picking at given position. + + :param float x: X coordinate in plot data frame of reference + :param float y: Y coordinate in plot data frame of reference + :returns: + Result of picking as a list of indices or None if nothing picked + :rtype: Union[List[int],None] + """ + return None + + def render(self, context): + """Performs OpenGL rendering of the item. + + :param RenderContext context: Rendering context information + """ + pass + + def discard(self): + """Discards OpenGL resources this item has created.""" + pass diff --git a/silx/gui/plot/backends/glutils/GLPlotTriangles.py b/silx/gui/plot/backends/glutils/GLPlotTriangles.py index 7aeb5ab..d5ba1a6 100644 --- a/silx/gui/plot/backends/glutils/GLPlotTriangles.py +++ b/silx/gui/plot/backends/glutils/GLPlotTriangles.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2019 European Synchrotron Radiation Facility +# Copyright (c) 2019-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -38,9 +38,10 @@ import numpy from .....math.combo import min_max from .... import _glutils as glutils from ...._glutils import gl +from .GLPlotItem import GLPlotItem -class GLPlotTriangles(object): +class GLPlotTriangles(GLPlotItem): """Handle rendering of a set of colored triangles""" _PROGRAM = glutils.Program( @@ -81,6 +82,7 @@ class GLPlotTriangles(object): :param numpy.ndarray triangles: (N, 3) array of indices of triangles :param float alpha: Opacity in [0, 1] """ + super().__init__() # Check and convert input data x = numpy.ravel(numpy.array(x, dtype=numpy.float32)) y = numpy.ravel(numpy.array(y, dtype=numpy.float32)) @@ -161,12 +163,10 @@ class GLPlotTriangles(object): usage=gl.GL_STATIC_DRAW, target=gl.GL_ELEMENT_ARRAY_BUFFER) - def render(self, matrix, isXLog, isYLog): + def render(self, context): """Perform rendering - :param numpy.ndarray matrix: 4x4 transform matrix to use - :param bool isXLog: - :param bool isYLog: + :param RenderContext context: Rendering information """ self.prepare() @@ -178,7 +178,7 @@ class GLPlotTriangles(object): gl.glUniformMatrix4fv(self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE, - matrix.astype(numpy.float32)) + context.matrix.astype(numpy.float32)) gl.glUniform1f(self._PROGRAM.uniforms['alpha'], self.__alpha) diff --git a/silx/gui/plot/backends/glutils/GLText.py b/silx/gui/plot/backends/glutils/GLText.py index 725c12c..d6ae6fa 100644 --- a/silx/gui/plot/backends/glutils/GLText.py +++ b/silx/gui/plot/backends/glutils/GLText.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2019 European Synchrotron Radiation Facility +# Copyright (c) 2014-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -140,7 +140,9 @@ class Text2D(object): color=(0., 0., 0., 1.), bgColor=None, align=LEFT, valign=BASELINE, - rotate=0): + rotate=0, + devicePixelRatio= 1.): + self.devicePixelRatio = devicePixelRatio self._vertices = None self._text = text self.x = x @@ -160,30 +162,35 @@ class Text2D(object): self._rotate = numpy.radians(rotate) - def _getTexture(self, text): + def _getTexture(self, text, devicePixelRatio): # Retrieve/initialize texture cache for current context + textureKey = text, devicePixelRatio + context = Context.getCurrent() if context not in self._textures: self._textures[context] = _Cache( callback=lambda key, value: value[0].discard()) textures = self._textures[context] - if text not in textures: - image, offset = font.rasterText(text, - font.getDefaultFontFamily()) - if text not in self._sizes: - self._sizes[text] = image.shape[1], image.shape[0] - - textures[text] = ( - Texture(gl.GL_RED, - data=image, - minFilter=gl.GL_NEAREST, - magFilter=gl.GL_NEAREST, - wrap=(gl.GL_CLAMP_TO_EDGE, - gl.GL_CLAMP_TO_EDGE)), - offset) - - return textures[text] + if textureKey not in textures: + image, offset = font.rasterText( + text, + font.getDefaultFontFamily(), + devicePixelRatio=self.devicePixelRatio) + if textureKey not in self._sizes: + self._sizes[textureKey] = image.shape[1], image.shape[0] + + texture = Texture( + gl.GL_RED, + data=image, + minFilter=gl.GL_NEAREST, + magFilter=gl.GL_NEAREST, + wrap=(gl.GL_CLAMP_TO_EDGE, + gl.GL_CLAMP_TO_EDGE)) + texture.prepare() + textures[textureKey] = texture, offset + + return textures[textureKey] @property def text(self): @@ -191,11 +198,14 @@ class Text2D(object): @property def size(self): - if self.text not in self._sizes: - image, offset = font.rasterText(self.text, - font.getDefaultFontFamily()) - self._sizes[self.text] = image.shape[1], image.shape[0] - return self._sizes[self.text] + textureKey = self.text, self.devicePixelRatio + if textureKey not in self._sizes: + image, offset = font.rasterText( + self.text, + font.getDefaultFontFamily(), + devicePixelRatio=self.devicePixelRatio) + self._sizes[textureKey] = image.shape[1], image.shape[0] + return self._sizes[textureKey] def getVertices(self, offset, shape): height, width = shape @@ -238,7 +248,7 @@ class Text2D(object): prog.use() texUnit = 0 - texture, offset = self._getTexture(self.text) + texture, offset = self._getTexture(self.text, self.devicePixelRatio) gl.glUniform1i(prog.uniforms['texText'], texUnit) diff --git a/silx/gui/plot/backends/glutils/GLTexture.py b/silx/gui/plot/backends/glutils/GLTexture.py index 118a36f..37fbdd0 100644 --- a/silx/gui/plot/backends/glutils/GLTexture.py +++ b/silx/gui/plot/backends/glutils/GLTexture.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2019 European Synchrotron Radiation Facility +# Copyright (c) 2014-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -98,6 +98,7 @@ class Image(object): minFilter=self._MIN_FILTER, magFilter=self._MAG_FILTER, wrap=self._WRAP) + texture.prepare() vertices = numpy.array(( (0., 0., 0., 0.), (self.width, 0., 1., 0.), @@ -177,6 +178,7 @@ class Image(object): (xOrig, yOrig + hData, 0., vMax), (xOrig + wData, yOrig + hData, uMax, vMax)), dtype=numpy.float32) + texture.prepare() tiles.append((texture, vertices, {'xOrigData': xOrig, 'yOrigData': yOrig, 'wData': wData, 'hData': hData})) @@ -203,6 +205,7 @@ class Image(object): texture.update(format_, data[yOrig:yOrig+height, xOrig:xOrig+width], texUnit=texUnit) + texture.prepare() # TODO check # width=info['wData'], height=info['hData'], # texUnit=texUnit, unpackAlign=unpackAlign, diff --git a/silx/gui/plot/backends/glutils/__init__.py b/silx/gui/plot/backends/glutils/__init__.py index d58c084..f87d7c1 100644 --- a/silx/gui/plot/backends/glutils/__init__.py +++ b/silx/gui/plot/backends/glutils/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2014-2019 European Synchrotron Radiation Facility +# Copyright (c) 2014-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -39,6 +39,7 @@ _logger = logging.getLogger(__name__) from .GLPlotCurve import * # noqa from .GLPlotFrame import * # noqa from .GLPlotImage import * # noqa +from .GLPlotItem import GLPlotItem, RenderContext # noqa from .GLPlotTriangles import GLPlotTriangles # noqa from .GLSupport import * # noqa from .GLText import * # noqa diff --git a/silx/gui/plot/items/__init__.py b/silx/gui/plot/items/__init__.py index 4d4eac0..0484025 100644 --- a/silx/gui/plot/items/__init__.py +++ b/silx/gui/plot/items/__init__.py @@ -32,7 +32,8 @@ __authors__ = ["T. Vincent"] __license__ = "MIT" __date__ = "22/06/2017" -from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa +from .core import (Item, DataItem, # noqa + LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa AlphaMixIn, LineMixIn, ScatterVisualizationMixIn, # noqa ComplexMixIn, ItemChangedType, PointsBase) # noqa diff --git a/silx/gui/plot/items/_arc_roi.py b/silx/gui/plot/items/_arc_roi.py new file mode 100644 index 0000000..a22cc3d --- /dev/null +++ b/silx/gui/plot/items/_arc_roi.py @@ -0,0 +1,873 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides Arc ROI item for the :class:`~silx.gui.plot.PlotWidget`. +""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "28/06/2018" + +import numpy + +from ... import utils +from .. import items +from ...colors import rgba +from ....utils.proxy import docstring +from ._roi_base import HandleBasedROI +from ._roi_base import InteractionModeMixIn +from ._roi_base import RoiInteractionMode + + +class _ArcGeometry: + """ + Non-mutable object to store the geometry of the arc ROI. + + The aim is is to switch between consistent state without dealing with + intermediate values. + """ + def __init__(self, center, startPoint, endPoint, radius, + weight, startAngle, endAngle, closed=False): + """Constructor for a consistent arc geometry. + + There is also specific class method to create different kind of arc + geometry. + """ + self.center = center + self.startPoint = startPoint + self.endPoint = endPoint + self.radius = radius + self.weight = weight + self.startAngle = startAngle + self.endAngle = endAngle + self._closed = closed + + @classmethod + def createEmpty(cls): + """Create an arc geometry from an empty shape + """ + zero = numpy.array([0, 0]) + return cls(zero, zero.copy(), zero.copy(), 0, 0, 0, 0) + + @classmethod + def createRect(cls, startPoint, endPoint, weight): + """Create an arc geometry from a definition of a rectangle + """ + return cls(None, startPoint, endPoint, None, weight, None, None, False) + + @classmethod + def createCircle(cls, center, startPoint, endPoint, radius, + weight, startAngle, endAngle): + """Create an arc geometry from a definition of a circle + """ + return cls(center, startPoint, endPoint, radius, + weight, startAngle, endAngle, True) + + def withWeight(self, weight): + """Return a new geometry based on this object, with a specific weight + """ + return _ArcGeometry(self.center, self.startPoint, self.endPoint, + self.radius, weight, + self.startAngle, self.endAngle, self._closed) + + def withRadius(self, radius): + """Return a new geometry based on this object, with a specific radius. + + The weight and the center is conserved. + """ + startPoint = self.center + (self.startPoint - self.center) / self.radius * radius + endPoint = self.center + (self.endPoint - self.center) / self.radius * radius + return _ArcGeometry(self.center, startPoint, endPoint, + radius, self.weight, + self.startAngle, self.endAngle, self._closed) + + def withStartAngle(self, startAngle): + """Return a new geometry based on this object, with a specific start angle + """ + vector = numpy.array([numpy.cos(startAngle), numpy.sin(startAngle)]) + startPoint = self.center + vector * self.radius + + # Never add more than 180 to maintain coherency + deltaAngle = startAngle - self.startAngle + if deltaAngle > numpy.pi: + deltaAngle -= numpy.pi * 2 + elif deltaAngle < -numpy.pi: + deltaAngle += numpy.pi * 2 + + startAngle = self.startAngle + deltaAngle + return _ArcGeometry( + self.center, + startPoint, + self.endPoint, + self.radius, + self.weight, + startAngle, + self.endAngle, + self._closed, + ) + + def withEndAngle(self, endAngle): + """Return a new geometry based on this object, with a specific end angle + """ + vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)]) + endPoint = self.center + vector * self.radius + + # Never add more than 180 to maintain coherency + deltaAngle = endAngle - self.endAngle + if deltaAngle > numpy.pi: + deltaAngle -= numpy.pi * 2 + elif deltaAngle < -numpy.pi: + deltaAngle += numpy.pi * 2 + + endAngle = self.endAngle + deltaAngle + return _ArcGeometry( + self.center, + self.startPoint, + endPoint, + self.radius, + self.weight, + self.startAngle, + endAngle, + self._closed, + ) + + def translated(self, dx, dy): + """Return the translated geometry by dx, dy""" + delta = numpy.array([dx, dy]) + center = None if self.center is None else self.center + delta + startPoint = None if self.startPoint is None else self.startPoint + delta + endPoint = None if self.endPoint is None else self.endPoint + delta + return _ArcGeometry(center, startPoint, endPoint, + self.radius, self.weight, + self.startAngle, self.endAngle, self._closed) + + def getKind(self): + """Returns the kind of shape defined""" + if self.center is None: + return "rect" + elif numpy.isnan(self.startAngle): + return "point" + elif self.isClosed(): + if self.weight <= 0 or self.weight * 0.5 >= self.radius: + return "circle" + else: + return "donut" + else: + if self.weight * 0.5 < self.radius: + return "arc" + else: + return "camembert" + + def isClosed(self): + """Returns True if the geometry is a circle like""" + if self._closed is not None: + return self._closed + delta = numpy.abs(self.endAngle - self.startAngle) + self._closed = numpy.isclose(delta, numpy.pi * 2) + return self._closed + + def __str__(self): + return str((self.center, + self.startPoint, + self.endPoint, + self.radius, + self.weight, + self.startAngle, + self.endAngle, + self._closed)) + + +class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn): + """A ROI identifying an arc of a circle with a width. + + This ROI provides + - 3 handle to control the curvature + - 1 handle to control the weight + - 1 anchor to translate the shape. + """ + + ICON = 'add-shape-arc' + NAME = 'arc ROI' + SHORT_NAME = "arc" + """Metadata for this kind of ROI""" + + _plotShape = "line" + """Plot shape which is used for the first interaction""" + + ThreePointMode = RoiInteractionMode("3 points", "Provides 3 points to define the main radius circle") + PolarMode = RoiInteractionMode("Polar", "Provides anchors to edit the ROI in polar coords") + # FIXME: MoveMode was designed cause there is too much anchors + # FIXME: It would be good replace it by a dnd on the shape + MoveMode = RoiInteractionMode("Translation", "Provides anchors to only move the ROI") + + def __init__(self, parent=None): + HandleBasedROI.__init__(self, parent=parent) + items.LineMixIn.__init__(self) + InteractionModeMixIn.__init__(self) + + self._geometry = _ArcGeometry.createEmpty() + self._handleLabel = self.addLabelHandle() + + self._handleStart = self.addHandle() + self._handleMid = self.addHandle() + self._handleEnd = self.addHandle() + self._handleWeight = self.addHandle() + self._handleWeight._setConstraint(self._arcCurvatureMarkerConstraint) + self._handleMove = self.addTranslateHandle() + + shape = items.Shape("polygon") + shape.setPoints([[0, 0], [0, 0]]) + shape.setColor(rgba(self.getColor())) + shape.setFill(False) + shape.setOverlay(True) + shape.setLineStyle(self.getLineStyle()) + shape.setLineWidth(self.getLineWidth()) + self.__shape = shape + self.addItem(shape) + + self._initInteractionMode(self.ThreePointMode) + self._interactiveModeUpdated(self.ThreePointMode) + + def availableInteractionModes(self): + """Returns the list of available interaction modes + + :rtype: List[RoiInteractionMode] + """ + return [self.ThreePointMode, self.PolarMode, self.MoveMode] + + def _interactiveModeUpdated(self, modeId): + """Set the interaction mode. + + :param RoiInteractionMode modeId: + """ + if modeId is self.ThreePointMode: + self._handleStart.setSymbol("s") + self._handleMid.setSymbol("s") + self._handleEnd.setSymbol("s") + self._handleWeight.setSymbol("d") + self._handleMove.setSymbol("+") + elif modeId is self.PolarMode: + self._handleStart.setSymbol("o") + self._handleMid.setSymbol("o") + self._handleEnd.setSymbol("o") + self._handleWeight.setSymbol("d") + self._handleMove.setSymbol("+") + elif modeId is self.MoveMode: + self._handleStart.setSymbol("") + self._handleMid.setSymbol("+") + self._handleEnd.setSymbol("") + self._handleWeight.setSymbol("") + self._handleMove.setSymbol("+") + else: + assert False + if self._geometry.isClosed(): + if modeId != self.MoveMode: + self._handleStart.setSymbol("x") + self._handleEnd.setSymbol("x") + self._updateHandles() + + def _updated(self, event=None, checkVisibility=True): + if event == items.ItemChangedType.VISIBLE: + self._updateItemProperty(event, self, self.__shape) + super(ArcROI, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + super(ArcROI, self)._updatedStyle(event, style) + self.__shape.setColor(style.getColor()) + self.__shape.setLineStyle(style.getLineStyle()) + self.__shape.setLineWidth(style.getLineWidth()) + + def setFirstShapePoints(self, points): + """"Initialize the ROI using the points from the first interaction. + + This interaction is constrained by the plot API and only supports few + shapes. + """ + # The first shape is a line + point0 = points[0] + point1 = points[1] + + # Compute a non collinear point for the curvature + center = (point1 + point0) * 0.5 + normal = point1 - center + normal = numpy.array((normal[1], -normal[0])) + defaultCurvature = numpy.pi / 5.0 + weightCoef = 0.20 + mid = center - normal * defaultCurvature + distance = numpy.linalg.norm(point0 - point1) + weight = distance * weightCoef + + geometry = self._createGeometryFromControlPoints(point0, mid, point1, weight) + self._geometry = geometry + self._updateHandles() + + def _updateText(self, text): + self._handleLabel.setText(text) + + def _updateMidHandle(self): + """Keep the same geometry, but update the location of the control + points. + + So calling this function do not trigger sigRegionChanged. + """ + geometry = self._geometry + + if geometry.isClosed(): + start = numpy.array(self._handleStart.getPosition()) + midPos = geometry.center + geometry.center - start + else: + if geometry.center is None: + midPos = geometry.startPoint * 0.5 + geometry.endPoint * 0.5 + else: + midAngle = geometry.startAngle * 0.5 + geometry.endAngle * 0.5 + vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)]) + midPos = geometry.center + geometry.radius * vector + + with utils.blockSignals(self._handleMid): + self._handleMid.setPosition(*midPos) + + def _updateWeightHandle(self): + geometry = self._geometry + if geometry.center is None: + # rectangle + center = (geometry.startPoint + geometry.endPoint) * 0.5 + normal = geometry.endPoint - geometry.startPoint + normal = numpy.array((normal[1], -normal[0])) + distance = numpy.linalg.norm(normal) + if distance != 0: + normal = normal / distance + weightPos = center + normal * geometry.weight * 0.5 + else: + if geometry.isClosed(): + midAngle = geometry.startAngle + numpy.pi * 0.5 + elif geometry.center is not None: + midAngle = (geometry.startAngle + geometry.endAngle) * 0.5 + vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)]) + weightPos = geometry.center + (geometry.radius + geometry.weight * 0.5) * vector + + with utils.blockSignals(self._handleWeight): + self._handleWeight.setPosition(*weightPos) + + def _getWeightFromHandle(self, weightPos): + geometry = self._geometry + if geometry.center is None: + # rectangle + center = (geometry.startPoint + geometry.endPoint) * 0.5 + return numpy.linalg.norm(center - weightPos) * 2 + else: + distance = numpy.linalg.norm(geometry.center - weightPos) + return abs(distance - geometry.radius) * 2 + + def _updateHandles(self): + geometry = self._geometry + with utils.blockSignals(self._handleStart): + self._handleStart.setPosition(*geometry.startPoint) + with utils.blockSignals(self._handleEnd): + self._handleEnd.setPosition(*geometry.endPoint) + + self._updateMidHandle() + self._updateWeightHandle() + self._updateShape() + + def _updateCurvature(self, start, mid, end, updateCurveHandles, checkClosed=False, updateStart=False): + """Update the curvature using 3 control points in the curve + + :param bool updateCurveHandles: If False curve handles are already at + the right location + """ + if checkClosed: + closed = self._isCloseInPixel(start, end) + else: + closed = self._geometry.isClosed() + if closed: + if updateStart: + start = end + else: + end = start + + if updateCurveHandles: + with utils.blockSignals(self._handleStart): + self._handleStart.setPosition(*start) + with utils.blockSignals(self._handleMid): + self._handleMid.setPosition(*mid) + with utils.blockSignals(self._handleEnd): + self._handleEnd.setPosition(*end) + + weight = self._geometry.weight + geometry = self._createGeometryFromControlPoints(start, mid, end, weight, closed=closed) + self._geometry = geometry + + self._updateWeightHandle() + self._updateShape() + + def _updateCloseInAngle(self, geometry, updateStart): + azim = numpy.abs(geometry.endAngle - geometry.startAngle) + if numpy.pi < azim < 3 * numpy.pi: + closed = self._isCloseInPixel(geometry.startPoint, geometry.endPoint) + geometry._closed = closed + if closed: + sign = 1 if geometry.startAngle < geometry.endAngle else -1 + if updateStart: + geometry.startPoint = geometry.endPoint + geometry.startAngle = geometry.endAngle - sign * 2*numpy.pi + else: + geometry.endPoint = geometry.startPoint + geometry.endAngle = geometry.startAngle + sign * 2*numpy.pi + + def handleDragUpdated(self, handle, origin, previous, current): + modeId = self.getInteractionMode() + if handle is self._handleStart: + if modeId is self.ThreePointMode: + mid = numpy.array(self._handleMid.getPosition()) + end = numpy.array(self._handleEnd.getPosition()) + self._updateCurvature( + current, mid, end, checkClosed=True, updateStart=True, + updateCurveHandles=False + ) + elif modeId is self.PolarMode: + v = current - self._geometry.center + startAngle = numpy.angle(complex(v[0], v[1])) + geometry = self._geometry.withStartAngle(startAngle) + self._updateCloseInAngle(geometry, updateStart=True) + self._geometry = geometry + self._updateHandles() + elif handle is self._handleMid: + if modeId is self.ThreePointMode: + if self._geometry.isClosed(): + radius = numpy.linalg.norm(self._geometry.center - current) + self._geometry = self._geometry.withRadius(radius) + self._updateHandles() + else: + start = numpy.array(self._handleStart.getPosition()) + end = numpy.array(self._handleEnd.getPosition()) + self._updateCurvature(start, current, end, updateCurveHandles=False) + elif modeId is self.PolarMode: + radius = numpy.linalg.norm(self._geometry.center - current) + self._geometry = self._geometry.withRadius(radius) + self._updateHandles() + elif modeId is self.MoveMode: + delta = current - previous + self.translate(*delta) + elif handle is self._handleEnd: + if modeId is self.ThreePointMode: + start = numpy.array(self._handleStart.getPosition()) + mid = numpy.array(self._handleMid.getPosition()) + self._updateCurvature( + start, mid, current, checkClosed=True, updateStart=False, + updateCurveHandles=False + ) + elif modeId is self.PolarMode: + v = current - self._geometry.center + endAngle = numpy.angle(complex(v[0], v[1])) + geometry = self._geometry.withEndAngle(endAngle) + self._updateCloseInAngle(geometry, updateStart=False) + self._geometry = geometry + self._updateHandles() + elif handle is self._handleWeight: + weight = self._getWeightFromHandle(current) + self._geometry = self._geometry.withWeight(weight) + self._updateShape() + elif handle is self._handleMove: + delta = current - previous + self.translate(*delta) + + def _isCloseInPixel(self, point1, point2): + manager = self.parent() + if manager is None: + return False + plot = manager.parent() + if plot is None: + return False + point1 = plot.dataToPixel(*point1) + if point1 is None: + return False + point2 = plot.dataToPixel(*point2) + if point2 is None: + return False + return abs(point1[0] - point2[0]) + abs(point1[1] - point2[1]) < 15 + + def _normalizeGeometry(self): + """Keep the same phisical geometry, but with normalized parameters. + """ + geometry = self._geometry + if geometry.weight * 0.5 >= geometry.radius: + radius = (geometry.weight * 0.5 + geometry.radius) * 0.5 + geometry = geometry.withRadius(radius) + geometry = geometry.withWeight(radius * 2) + self._geometry = geometry + return True + return False + + def handleDragFinished(self, handle, origin, current): + modeId = self.getInteractionMode() + if handle in [self._handleStart, self._handleMid, self._handleEnd]: + if modeId is self.ThreePointMode: + self._normalizeGeometry() + self._updateHandles() + + if self._geometry.isClosed(): + if modeId is self.MoveMode: + self._handleStart.setSymbol("") + self._handleEnd.setSymbol("") + else: + self._handleStart.setSymbol("x") + self._handleEnd.setSymbol("x") + else: + if modeId is self.ThreePointMode: + self._handleStart.setSymbol("s") + self._handleEnd.setSymbol("s") + elif modeId is self.PolarMode: + self._handleStart.setSymbol("o") + self._handleEnd.setSymbol("o") + if modeId is self.MoveMode: + self._handleStart.setSymbol("") + self._handleEnd.setSymbol("") + + def _createGeometryFromControlPoints(self, start, mid, end, weight, closed=None): + """Returns the geometry of the object""" + if closed or (closed is None and numpy.allclose(start, end)): + # Special arc: It's a closed circle + center = (start + mid) * 0.5 + radius = numpy.linalg.norm(start - center) + v = start - center + startAngle = numpy.angle(complex(v[0], v[1])) + endAngle = startAngle + numpy.pi * 2.0 + return _ArcGeometry.createCircle( + center, start, end, radius, weight, startAngle, endAngle + ) + + elif numpy.linalg.norm(numpy.cross(mid - start, end - start)) < 1e-5: + # Degenerated arc, it's a rectangle + return _ArcGeometry.createRect(start, end, weight) + else: + center, radius = self._circleEquation(start, mid, end) + v = start - center + startAngle = numpy.angle(complex(v[0], v[1])) + v = mid - center + midAngle = numpy.angle(complex(v[0], v[1])) + v = end - center + endAngle = numpy.angle(complex(v[0], v[1])) + + # Is it clockwise or anticlockwise + relativeMid = (endAngle - midAngle + 2 * numpy.pi) % (2 * numpy.pi) + relativeEnd = (endAngle - startAngle + 2 * numpy.pi) % (2 * numpy.pi) + if relativeMid < relativeEnd: + if endAngle < startAngle: + endAngle += 2 * numpy.pi + else: + if endAngle > startAngle: + endAngle -= 2 * numpy.pi + + return _ArcGeometry(center, start, end, + radius, weight, startAngle, endAngle) + + def _createShapeFromGeometry(self, geometry): + kind = geometry.getKind() + if kind == "rect": + # It is not an arc + # but we can display it as an intermediate shape + normal = geometry.endPoint - geometry.startPoint + normal = numpy.array((normal[1], -normal[0])) + distance = numpy.linalg.norm(normal) + if distance != 0: + normal /= distance + points = numpy.array([ + geometry.startPoint + normal * geometry.weight * 0.5, + geometry.endPoint + normal * geometry.weight * 0.5, + geometry.endPoint - normal * geometry.weight * 0.5, + geometry.startPoint - normal * geometry.weight * 0.5]) + elif kind == "point": + # It is not an arc + # but we can display it as an intermediate shape + # NOTE: At least 2 points are expected + points = numpy.array([geometry.startPoint, geometry.startPoint]) + elif kind == "circle": + outerRadius = geometry.radius + geometry.weight * 0.5 + angles = numpy.linspace(0, 2 * numpy.pi, num=50) + # It's a circle + points = [] + numpy.append(angles, angles[-1]) + for angle in angles: + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points.append(geometry.center + direction * outerRadius) + points = numpy.array(points) + elif kind == "donut": + innerRadius = geometry.radius - geometry.weight * 0.5 + outerRadius = geometry.radius + geometry.weight * 0.5 + angles = numpy.linspace(0, 2 * numpy.pi, num=50) + # It's a donut + points = [] + # NOTE: NaN value allow to create 2 separated circle shapes + # using a single plot item. It's a kind of cheat + points.append(numpy.array([float("nan"), float("nan")])) + for angle in angles: + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points.insert(0, geometry.center + direction * innerRadius) + points.append(geometry.center + direction * outerRadius) + points.append(numpy.array([float("nan"), float("nan")])) + points = numpy.array(points) + else: + innerRadius = geometry.radius - geometry.weight * 0.5 + outerRadius = geometry.radius + geometry.weight * 0.5 + + delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1 + if geometry.startAngle == geometry.endAngle: + # Degenerated, it's a line (single radius) + angle = geometry.startAngle + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points = [] + points.append(geometry.center + direction * innerRadius) + points.append(geometry.center + direction * outerRadius) + return numpy.array(points) + + angles = numpy.arange(geometry.startAngle, geometry.endAngle, delta) + if angles[-1] != geometry.endAngle: + angles = numpy.append(angles, geometry.endAngle) + + if kind == "camembert": + # It's a part of camembert + points = [] + points.append(geometry.center) + points.append(geometry.startPoint) + delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1 + for angle in angles: + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points.append(geometry.center + direction * outerRadius) + points.append(geometry.endPoint) + points.append(geometry.center) + elif kind == "arc": + # It's a part of donut + points = [] + points.append(geometry.startPoint) + for angle in angles: + direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) + points.insert(0, geometry.center + direction * innerRadius) + points.append(geometry.center + direction * outerRadius) + points.insert(0, geometry.endPoint) + points.append(geometry.endPoint) + else: + assert False + + points = numpy.array(points) + + return points + + def _updateShape(self): + geometry = self._geometry + points = self._createShapeFromGeometry(geometry) + self.__shape.setPoints(points) + + index = numpy.nanargmin(points[:, 1]) + pos = points[index] + with utils.blockSignals(self._handleLabel): + self._handleLabel.setPosition(pos[0], pos[1]) + + if geometry.center is None: + movePos = geometry.startPoint * 0.34 + geometry.endPoint * 0.66 + else: + movePos = geometry.center + + with utils.blockSignals(self._handleMove): + self._handleMove.setPosition(*movePos) + + self.sigRegionChanged.emit() + + def getGeometry(self): + """Returns a tuple containing the geometry of this ROI + + It is a symmetric function of :meth:`setGeometry`. + + If `startAngle` is smaller than `endAngle` the rotation is clockwise, + else the rotation is anticlockwise. + + :rtype: Tuple[numpy.ndarray,float,float,float,float] + :raise ValueError: In case the ROI can't be represented as section of + a circle + """ + geometry = self._geometry + if geometry.center is None: + raise ValueError("This ROI can't be represented as a section of circle") + return geometry.center, self.getInnerRadius(), self.getOuterRadius(), geometry.startAngle, geometry.endAngle + + def isClosed(self): + """Returns true if the arc is a closed shape, like a circle or a donut. + + :rtype: bool + """ + return self._geometry.isClosed() + + def getCenter(self): + """Returns the center of the circle used to draw arcs of this ROI. + + This center is usually outside the the shape itself. + + :rtype: numpy.ndarray + """ + return self._geometry.center + + def getStartAngle(self): + """Returns the angle of the start of the section of this ROI (in radian). + + If `startAngle` is smaller than `endAngle` the rotation is clockwise, + else the rotation is anticlockwise. + + :rtype: float + """ + return self._geometry.startAngle + + def getEndAngle(self): + """Returns the angle of the end of the section of this ROI (in radian). + + If `startAngle` is smaller than `endAngle` the rotation is clockwise, + else the rotation is anticlockwise. + + :rtype: float + """ + return self._geometry.endAngle + + def getInnerRadius(self): + """Returns the radius of the smaller arc used to draw this ROI. + + :rtype: float + """ + geometry = self._geometry + radius = geometry.radius - geometry.weight * 0.5 + if radius < 0: + radius = 0 + return radius + + def getOuterRadius(self): + """Returns the radius of the bigger arc used to draw this ROI. + + :rtype: float + """ + geometry = self._geometry + radius = geometry.radius + geometry.weight * 0.5 + return radius + + def setGeometry(self, center, innerRadius, outerRadius, startAngle, endAngle): + """ + Set the geometry of this arc. + + :param numpy.ndarray center: Center of the circle. + :param float innerRadius: Radius of the smaller arc of the section. + :param float outerRadius: Weight of the bigger arc of the section. + It have to be bigger than `innerRadius` + :param float startAngle: Location of the start of the section (in radian) + :param float endAngle: Location of the end of the section (in radian). + If `startAngle` is smaller than `endAngle` the rotation is clockwise, + else the rotation is anticlockwise. + """ + assert innerRadius <= outerRadius + assert numpy.abs(startAngle - endAngle) <= 2 * numpy.pi + center = numpy.array(center) + radius = (innerRadius + outerRadius) * 0.5 + weight = outerRadius - innerRadius + + vector = numpy.array([numpy.cos(startAngle), numpy.sin(startAngle)]) + startPoint = center + vector * radius + vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)]) + endPoint = center + vector * radius + + geometry = _ArcGeometry(center, startPoint, endPoint, + radius, weight, + startAngle, endAngle, closed=None) + self._geometry = geometry + self._updateHandles() + + @docstring(HandleBasedROI) + def contains(self, position): + # first check distance, fastest + center = self.getCenter() + distance = numpy.sqrt((position[1] - center[1]) ** 2 + ((position[0] - center[0])) ** 2) + is_in_distance = self.getInnerRadius() <= distance <= self.getOuterRadius() + if not is_in_distance: + return False + rel_pos = position[1] - center[1], position[0] - center[0] + angle = numpy.arctan2(*rel_pos) + # angle is inside [-pi, pi] + + # Normalize the start angle between [-pi, pi] + # with a positive angle range + start_angle = self.getStartAngle() + end_angle = self.getEndAngle() + azim_range = end_angle - start_angle + if azim_range < 0: + start_angle = end_angle + azim_range = -azim_range + start_angle = numpy.mod(start_angle + numpy.pi, 2 * numpy.pi) - numpy.pi + + if angle < start_angle: + angle += 2 * numpy.pi + return start_angle <= angle <= start_angle + azim_range + + def translate(self, x, y): + self._geometry = self._geometry.translated(x, y) + self._updateHandles() + + def _arcCurvatureMarkerConstraint(self, x, y): + """Curvature marker remains on perpendicular bisector""" + geometry = self._geometry + if geometry.center is None: + center = (geometry.startPoint + geometry.endPoint) * 0.5 + vector = geometry.startPoint - geometry.endPoint + vector = numpy.array((vector[1], -vector[0])) + vdist = numpy.linalg.norm(vector) + if vdist != 0: + normal = numpy.array((vector[1], -vector[0])) / vdist + else: + normal = numpy.array((0, 0)) + else: + if geometry.isClosed(): + midAngle = geometry.startAngle + numpy.pi * 0.5 + else: + midAngle = (geometry.startAngle + geometry.endAngle) * 0.5 + normal = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)]) + center = geometry.center + dist = numpy.dot(normal, (numpy.array((x, y)) - center)) + dist = numpy.clip(dist, geometry.radius, geometry.radius * 2) + x, y = center + dist * normal + return x, y + + @staticmethod + def _circleEquation(pt1, pt2, pt3): + """Circle equation from 3 (x, y) points + + :return: Position of the center of the circle and the radius + :rtype: Tuple[Tuple[float,float],float] + """ + x, y, z = complex(*pt1), complex(*pt2), complex(*pt3) + w = z - x + w /= y - x + c = (x - y) * (w - abs(w) ** 2) / 2j / w.imag - x + return numpy.array((-c.real, -c.imag)), abs(c + x) + + def __str__(self): + try: + center, innerRadius, outerRadius, startAngle, endAngle = self.getGeometry() + params = center[0], center[1], innerRadius, outerRadius, startAngle, endAngle + params = 'center: %f %f; radius: %f %f; angles: %f %f' % params + except ValueError: + params = "invalid" + return "%s(%s)" % (self.__class__.__name__, params) diff --git a/silx/gui/plot/items/_pick.py b/silx/gui/plot/items/_pick.py index 4ddf4f6..8c8e781 100644 --- a/silx/gui/plot/items/_pick.py +++ b/silx/gui/plot/items/_pick.py @@ -48,7 +48,7 @@ class PickingResult(object): self._indices = None else: # Indices is set to None if indices array is empty - indices = numpy.array(indices, copy=False, dtype=numpy.int) + indices = numpy.array(indices, copy=False, dtype=numpy.int64) self._indices = None if indices.size == 0 else indices def getItem(self): diff --git a/silx/gui/plot/items/_roi_base.py b/silx/gui/plot/items/_roi_base.py new file mode 100644 index 0000000..3eb6cf4 --- /dev/null +++ b/silx/gui/plot/items/_roi_base.py @@ -0,0 +1,835 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2018-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module provides base components to create ROI item for +the :class:`~silx.gui.plot.PlotWidget`. + +.. inheritance-diagram:: + silx.gui.plot.items.roi + :parts: 1 +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "28/06/2018" + + +import logging +import numpy +import weakref + +from ....utils.weakref import WeakList +from ... import qt +from .. import items +from ..items import core +from ...colors import rgba +import silx.utils.deprecation +from ....utils.proxy import docstring + + +logger = logging.getLogger(__name__) + + +class _RegionOfInterestBase(qt.QObject): + """Base class of 1D and 2D region of interest + + :param QObject parent: See QObject + :param str name: The name of the ROI + """ + + sigAboutToBeRemoved = qt.Signal() + """Signal emitted just before this ROI is removed from its manager.""" + + sigItemChanged = qt.Signal(object) + """Signal emitted when item has changed. + + It provides a flag describing which property of the item has changed. + See :class:`ItemChangedType` for flags description. + """ + + def __init__(self, parent=None): + qt.QObject.__init__(self, parent=parent) + self.__name = '' + + def getName(self): + """Returns the name of the ROI + + :return: name of the region of interest + :rtype: str + """ + return self.__name + + def setName(self, name): + """Set the name of the ROI + + :param str name: name of the region of interest + """ + name = str(name) + if self.__name != name: + self.__name = name + self._updated(items.ItemChangedType.NAME) + + def _updated(self, event=None, checkVisibility=True): + """Implement Item mix-in update method by updating the plot items + + See :class:`~silx.gui.plot.items.Item._updated` + """ + self.sigItemChanged.emit(event) + + def contains(self, position): + """Returns True if the `position` is in this ROI. + + :param tuple[float,float] position: position to check + :return: True if the value / point is consider to be in the region of + interest. + :rtype: bool + """ + return False # Override in subclass to perform actual test + + +class RoiInteractionMode(object): + """Description of an interaction mode. + + An interaction mode provide a specific kind of interaction for a ROI. + A ROI can implement many interaction. + """ + + def __init__(self, label, description=None): + self._label = label + self._description = description + + @property + def label(self): + return self._label + + @property + def description(self): + return self._description + + +class InteractionModeMixIn(object): + """Mix in feature which can be implemented by a ROI object. + + This provides user interaction to switch between different + interaction mode to edit the ROI. + + This ROI modes have to be described using `RoiInteractionMode`, + and taken into account during interation with handles. + """ + + sigInteractionModeChanged = qt.Signal(object) + + def __init__(self): + self.__modeId = None + + def _initInteractionMode(self, modeId): + """Set the mode without updating anything. + + Must be one of the returned :meth:`availableInteractionModes`. + + :param RoiInteractionMode modeId: Mode to use + """ + self.__modeId = modeId + + def availableInteractionModes(self): + """Returns the list of available interaction modes + + Must be implemented when inherited to provide all available modes. + + :rtype: List[RoiInteractionMode] + """ + raise NotImplementedError() + + def setInteractionMode(self, modeId): + """Set the interaction mode. + + :param RoiInteractionMode modeId: Mode to use + """ + self.__modeId = modeId + self._interactiveModeUpdated(modeId) + self.sigInteractionModeChanged.emit(modeId) + + def _interactiveModeUpdated(self, modeId): + """Called directly after an update of the mode. + + The signal `sigInteractionModeChanged` is triggered after this + call. + + Must be implemented when inherited to take care of the change. + """ + raise NotImplementedError() + + def getInteractionMode(self): + """Returns the interaction mode. + + Must be one of the returned :meth:`availableInteractionModes`. + + :rtype: RoiInteractionMode + """ + return self.__modeId + + +class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn): + """Object describing a region of interest in a plot. + + :param QObject parent: + The RegionOfInterestManager that created this object + """ + + _DEFAULT_LINEWIDTH = 1. + """Default line width of the curve""" + + _DEFAULT_LINESTYLE = '-' + """Default line style of the curve""" + + _DEFAULT_HIGHLIGHT_STYLE = items.CurveStyle(linewidth=2) + """Default highlight style of the item""" + + ICON, NAME, SHORT_NAME = None, None, None + """Metadata to describe the ROI in labels, tooltips and widgets + + Should be set by inherited classes to custom the ROI manager widget. + """ + + sigRegionChanged = qt.Signal() + """Signal emitted everytime the shape or position of the ROI changes""" + + sigEditingStarted = qt.Signal() + """Signal emitted when the user start editing the roi""" + + sigEditingFinished = qt.Signal() + """Signal emitted when the region edition is finished. During edition + sigEditionChanged will be emitted several times and + sigRegionEditionFinished only at end""" + + def __init__(self, parent=None): + # Avoid circular dependency + from ..tools import roi as roi_tools + assert parent is None or isinstance(parent, roi_tools.RegionOfInterestManager) + _RegionOfInterestBase.__init__(self, parent) + core.HighlightedMixIn.__init__(self) + self._color = rgba('red') + self._editable = False + self._selectable = False + self._focusProxy = None + self._visible = True + self._child = WeakList() + + def _connectToPlot(self, plot): + """Called after connection to a plot""" + for item in self.getItems(): + # This hack is needed to avoid reentrant call from _disconnectFromPlot + # to the ROI manager. It also speed up the item tests in _itemRemoved + item._roiGroup = True + plot.addItem(item) + + def _disconnectFromPlot(self, plot): + """Called before disconnection from a plot""" + for item in self.getItems(): + # The item could be already be removed by the plot + if item.getPlot() is not None: + del item._roiGroup + plot.removeItem(item) + + def _setItemName(self, item): + """Helper to generate a unique id to a plot item""" + legend = "__ROI-%d__%d" % (id(self), id(item)) + item.setName(legend) + + def setParent(self, parent): + """Set the parent of the RegionOfInterest + + :param Union[None,RegionOfInterestManager] parent: The new parent + """ + # Avoid circular dependency + from ..tools import roi as roi_tools + if (parent is not None and not isinstance(parent, roi_tools.RegionOfInterestManager)): + raise ValueError('Unsupported parent') + + previousParent = self.parent() + if previousParent is not None: + previousPlot = previousParent.parent() + if previousPlot is not None: + self._disconnectFromPlot(previousPlot) + super(RegionOfInterest, self).setParent(parent) + if parent is not None: + plot = parent.parent() + if plot is not None: + self._connectToPlot(plot) + + def addItem(self, item): + """Add an item to the set of this ROI children. + + This item will be added and removed to the plot used by the ROI. + + If the ROI is already part of a plot, the item will also be added to + the plot. + + It the item do not have a name already, a unique one is generated to + avoid item collision in the plot. + + :param silx.gui.plot.items.Item item: A plot item + """ + assert item is not None + self._child.append(item) + if item.getName() == '': + self._setItemName(item) + manager = self.parent() + if manager is not None: + plot = manager.parent() + if plot is not None: + item._roiGroup = True + plot.addItem(item) + + def removeItem(self, item): + """Remove an item from this ROI children. + + If the item is part of a plot it will be removed too. + + :param silx.gui.plot.items.Item item: A plot item + """ + assert item is not None + self._child.remove(item) + plot = item.getPlot() + if plot is not None: + del item._roiGroup + plot.removeItem(item) + + def getItems(self): + """Returns the list of PlotWidget items of this RegionOfInterest. + + :rtype: List[~silx.gui.plot.items.Item] + """ + return tuple(self._child) + + @classmethod + def _getShortName(cls): + """Return an human readable kind of ROI + + :rtype: str + """ + if hasattr(cls, "SHORT_NAME"): + name = cls.SHORT_NAME + if name is None: + name = cls.__name__ + return name + + def getColor(self): + """Returns the color of this ROI + + :rtype: QColor + """ + return qt.QColor.fromRgbF(*self._color) + + def setColor(self, color): + """Set the color used for this ROI. + + :param color: The color to use for ROI shape as + either a color name, a QColor, a list of uint8 or float in [0, 1]. + """ + color = rgba(color) + if color != self._color: + self._color = color + self._updated(items.ItemChangedType.COLOR) + + @silx.utils.deprecation.deprecated(reason='API modification', + replacement='getName()', + since_version=0.12) + def getLabel(self): + """Returns the label displayed for this ROI. + + :rtype: str + """ + return self.getName() + + @silx.utils.deprecation.deprecated(reason='API modification', + replacement='setName(name)', + since_version=0.12) + def setLabel(self, label): + """Set the label displayed with this ROI. + + :param str label: The text label to display + """ + self.setName(name=label) + + def isEditable(self): + """Returns whether the ROI is editable by the user or not. + + :rtype: bool + """ + return self._editable + + def setEditable(self, editable): + """Set whether the ROI can be changed interactively. + + :param bool editable: True to allow edition by the user, + False to disable. + """ + editable = bool(editable) + if self._editable != editable: + self._editable = editable + self._updated(items.ItemChangedType.EDITABLE) + + def isSelectable(self): + """Returns whether the ROI is selectable by the user or not. + + :rtype: bool + """ + return self._selectable + + def setSelectable(self, selectable): + """Set whether the ROI can be selected interactively. + + :param bool selectable: True to allow selection by the user, + False to disable. + """ + selectable = bool(selectable) + if self._selectable != selectable: + self._selectable = selectable + self._updated(items.ItemChangedType.SELECTABLE) + + def getFocusProxy(self): + """Returns the ROI which have to be selected when this ROI is selected, + else None if no proxy specified. + + :rtype: RegionOfInterest + """ + proxy = self._focusProxy + if proxy is None: + return None + proxy = proxy() + if proxy is None: + self._focusProxy = None + return proxy + + def setFocusProxy(self, roi): + """Set the real ROI which will be selected when this ROI is selected, + else None to remove the proxy already specified. + + :param RegionOfInterest roi: A ROI + """ + if roi is not None: + self._focusProxy = weakref.ref(roi) + else: + self._focusProxy = None + + def isVisible(self): + """Returns whether the ROI is visible in the plot. + + .. note:: + This does not take into account whether or not the plot + widget itself is visible (unlike :meth:`QWidget.isVisible` which + checks the visibility of all its parent widgets up to the window) + + :rtype: bool + """ + return self._visible + + def setVisible(self, visible): + """Set whether the plot items associated with this ROI are + visible in the plot. + + :param bool visible: True to show the ROI in the plot, False to + hide it. + """ + visible = bool(visible) + if self._visible != visible: + self._visible = visible + self._updated(items.ItemChangedType.VISIBLE) + + @classmethod + def showFirstInteractionShape(cls): + """Returns True if the shape created by the first interaction and + managed by the plot have to be visible. + + :rtype: bool + """ + return False + + @classmethod + def getFirstInteractionShape(cls): + """Returns the shape kind which will be used by the very first + interaction with the plot. + + This interactions are hardcoded inside the plot + + :rtype: str + """ + return cls._plotShape + + def setFirstShapePoints(self, points): + """"Initialize the ROI using the points from the first interaction. + + This interaction is constrained by the plot API and only supports few + shapes. + """ + raise NotImplementedError() + + def creationStarted(self): + """"Called when the ROI creation interaction was started. + """ + pass + + def creationFinalized(self): + """"Called when the ROI creation interaction was finalized. + """ + pass + + def _updateItemProperty(self, event, source, destination): + """Update the item property of a destination from an item source. + + :param items.ItemChangedType event: Property type to update + :param silx.gui.plot.items.Item source: The reference for the data + :param event Union[Item,List[Item]] destination: The item(s) to update + """ + if not isinstance(destination, (list, tuple)): + destination = [destination] + if event == items.ItemChangedType.NAME: + value = source.getName() + for d in destination: + d.setName(value) + elif event == items.ItemChangedType.EDITABLE: + value = source.isEditable() + for d in destination: + d.setEditable(value) + elif event == items.ItemChangedType.SELECTABLE: + value = source.isSelectable() + for d in destination: + d._setSelectable(value) + elif event == items.ItemChangedType.COLOR: + value = rgba(source.getColor()) + for d in destination: + d.setColor(value) + elif event == items.ItemChangedType.LINE_STYLE: + value = self.getLineStyle() + for d in destination: + d.setLineStyle(value) + elif event == items.ItemChangedType.LINE_WIDTH: + value = self.getLineWidth() + for d in destination: + d.setLineWidth(value) + elif event == items.ItemChangedType.SYMBOL: + value = self.getSymbol() + for d in destination: + d.setSymbol(value) + elif event == items.ItemChangedType.SYMBOL_SIZE: + value = self.getSymbolSize() + for d in destination: + d.setSymbolSize(value) + elif event == items.ItemChangedType.VISIBLE: + value = self.isVisible() + for d in destination: + d.setVisible(value) + else: + assert False + + def _updated(self, event=None, checkVisibility=True): + if event == items.ItemChangedType.HIGHLIGHTED: + style = self.getCurrentStyle() + self._updatedStyle(event, style) + else: + styleEvents = [items.ItemChangedType.COLOR, + items.ItemChangedType.LINE_STYLE, + items.ItemChangedType.LINE_WIDTH, + items.ItemChangedType.SYMBOL, + items.ItemChangedType.SYMBOL_SIZE] + if self.isHighlighted(): + styleEvents.append(items.ItemChangedType.HIGHLIGHTED_STYLE) + + if event in styleEvents: + style = self.getCurrentStyle() + self._updatedStyle(event, style) + + super(RegionOfInterest, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + """Called when the current displayed style of the ROI was changed. + + :param event: The event responsible of the change of the style + :param items.CurveStyle style: The current style + """ + pass + + def getCurrentStyle(self): + """Returns the current curve style. + + Curve style depends on curve highlighting + + :rtype: CurveStyle + """ + baseColor = rgba(self.getColor()) + if isinstance(self, core.LineMixIn): + baseLinestyle = self.getLineStyle() + baseLinewidth = self.getLineWidth() + else: + baseLinestyle = self._DEFAULT_LINESTYLE + baseLinewidth = self._DEFAULT_LINEWIDTH + if isinstance(self, core.SymbolMixIn): + baseSymbol = self.getSymbol() + baseSymbolsize = self.getSymbolSize() + else: + baseSymbol = 'o' + baseSymbolsize = 1 + + if self.isHighlighted(): + style = self.getHighlightedStyle() + color = style.getColor() + linestyle = style.getLineStyle() + linewidth = style.getLineWidth() + symbol = style.getSymbol() + symbolsize = style.getSymbolSize() + + return items.CurveStyle( + color=baseColor if color is None else color, + linestyle=baseLinestyle if linestyle is None else linestyle, + linewidth=baseLinewidth if linewidth is None else linewidth, + symbol=baseSymbol if symbol is None else symbol, + symbolsize=baseSymbolsize if symbolsize is None else symbolsize) + else: + return items.CurveStyle(color=baseColor, + linestyle=baseLinestyle, + linewidth=baseLinewidth, + symbol=baseSymbol, + symbolsize=baseSymbolsize) + + def _editingStarted(self): + assert self._editable is True + self.sigEditingStarted.emit() + + def _editingFinished(self): + self.sigEditingFinished.emit() + + +class HandleBasedROI(RegionOfInterest): + """Manage a ROI based on a set of handles""" + + def __init__(self, parent=None): + RegionOfInterest.__init__(self, parent=parent) + self._handles = [] + self._posOrigin = None + self._posPrevious = None + + def addUserHandle(self, item=None): + """ + Add a new free handle to the ROI. + + This handle do nothing. It have to be managed by the ROI + implementing this class. + + :param Union[None,silx.gui.plot.items.Marker] item: The new marker to + add, else None to create a default marker. + :rtype: silx.gui.plot.items.Marker + """ + return self.addHandle(item, role="user") + + def addLabelHandle(self, item=None): + """ + Add a new label handle to the ROI. + + This handle is not draggable nor selectable. + + It is displayed without symbol, but it is always visible anyway + the ROI is editable, in order to display text. + + :param Union[None,silx.gui.plot.items.Marker] item: The new marker to + add, else None to create a default marker. + :rtype: silx.gui.plot.items.Marker + """ + return self.addHandle(item, role="label") + + def addTranslateHandle(self, item=None): + """ + Add a new translate handle to the ROI. + + Dragging translate handles affect the position position of the ROI + but not the shape itself. + + :param Union[None,silx.gui.plot.items.Marker] item: The new marker to + add, else None to create a default marker. + :rtype: silx.gui.plot.items.Marker + """ + return self.addHandle(item, role="translate") + + def addHandle(self, item=None, role="default"): + """ + Add a new handle to the ROI. + + Dragging handles while affect the position or the shape of the + ROI. + + :param Union[None,silx.gui.plot.items.Marker] item: The new marker to + add, else None to create a default marker. + :rtype: silx.gui.plot.items.Marker + """ + if item is None: + item = items.Marker() + color = rgba(self.getColor()) + color = self._computeHandleColor(color) + item.setColor(color) + if role == "default": + item.setSymbol("s") + elif role == "user": + pass + elif role == "translate": + item.setSymbol("+") + elif role == "label": + item.setSymbol("") + + if role == "user": + pass + elif role == "label": + item._setSelectable(False) + item._setDraggable(False) + item.setVisible(True) + else: + self.__updateEditable(item, self.isEditable(), remove=False) + item._setSelectable(False) + + self._handles.append((item, role)) + self.addItem(item) + return item + + def removeHandle(self, handle): + data = [d for d in self._handles if d[0] is handle][0] + self._handles.remove(data) + role = data[1] + if role not in ["user", "label"]: + if self.isEditable(): + self.__updateEditable(handle, False) + self.removeItem(handle) + + def getHandles(self): + """Returns the list of handles of this HandleBasedROI. + + :rtype: List[~silx.gui.plot.items.Marker] + """ + return tuple(data[0] for data in self._handles) + + def _updated(self, event=None, checkVisibility=True): + """Implement Item mix-in update method by updating the plot items + + See :class:`~silx.gui.plot.items.Item._updated` + """ + if event == items.ItemChangedType.NAME: + self._updateText(self.getName()) + elif event == items.ItemChangedType.VISIBLE: + for item, role in self._handles: + visible = self.isVisible() + editionVisible = visible and self.isEditable() + if role not in ["user", "label"]: + item.setVisible(editionVisible) + else: + item.setVisible(visible) + elif event == items.ItemChangedType.EDITABLE: + for item, role in self._handles: + editable = self.isEditable() + if role not in ["user", "label"]: + self.__updateEditable(item, editable) + super(HandleBasedROI, self)._updated(event, checkVisibility) + + def _updatedStyle(self, event, style): + super(HandleBasedROI, self)._updatedStyle(event, style) + + # Update color of shape items in the plot + color = rgba(self.getColor()) + handleColor = self._computeHandleColor(color) + for item, role in self._handles: + if role == 'user': + pass + elif role == 'label': + item.setColor(color) + else: + item.setColor(handleColor) + + def __updateEditable(self, handle, editable, remove=True): + # NOTE: visibility change emit a position update event + handle.setVisible(editable and self.isVisible()) + handle._setDraggable(editable) + if editable: + handle.sigDragStarted.connect(self._handleEditingStarted) + handle.sigItemChanged.connect(self._handleEditingUpdated) + handle.sigDragFinished.connect(self._handleEditingFinished) + else: + if remove: + handle.sigDragStarted.disconnect(self._handleEditingStarted) + handle.sigItemChanged.disconnect(self._handleEditingUpdated) + handle.sigDragFinished.disconnect(self._handleEditingFinished) + + def _handleEditingStarted(self): + super(HandleBasedROI, self)._editingStarted() + handle = self.sender() + self._posOrigin = numpy.array(handle.getPosition()) + self._posPrevious = numpy.array(self._posOrigin) + self.handleDragStarted(handle, self._posOrigin) + + def _handleEditingUpdated(self): + if self._posOrigin is None: + # Avoid to handle events when visibility change + return + handle = self.sender() + current = numpy.array(handle.getPosition()) + self.handleDragUpdated(handle, self._posOrigin, self._posPrevious, current) + self._posPrevious = current + + def _handleEditingFinished(self): + handle = self.sender() + current = numpy.array(handle.getPosition()) + self.handleDragFinished(handle, self._posOrigin, current) + self._posPrevious = None + self._posOrigin = None + super(HandleBasedROI, self)._editingFinished() + + def isHandleBeingDragged(self): + """Returns True if one of the handles is currently being dragged. + + :rtype: bool + """ + return self._posOrigin is not None + + def handleDragStarted(self, handle, origin): + """Called when an handler drag started""" + pass + + def handleDragUpdated(self, handle, origin, previous, current): + """Called when an handle drag position changed""" + pass + + def handleDragFinished(self, handle, origin, current): + """Called when an handle drag finished""" + pass + + def _computeHandleColor(self, color): + """Returns the anchor color from the base ROI color + + :param Union[numpy.array,Tuple,List]: color + :rtype: Union[numpy.array,Tuple,List] + """ + return color[:3] + (0.5,) + + def _updateText(self, text): + """Update the text displayed by this ROI + + :param str text: A text + """ + pass diff --git a/silx/gui/plot/items/complex.py b/silx/gui/plot/items/complex.py index 8f0694d..0e492a0 100644 --- a/silx/gui/plot/items/complex.py +++ b/silx/gui/plot/items/complex.py @@ -124,10 +124,9 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn): """Overrides supported ComplexMode""" def __init__(self): - ImageBase.__init__(self) + ImageBase.__init__(self, numpy.zeros((0, 0), dtype=numpy.complex64)) ColormapMixIn.__init__(self) ComplexMixIn.__init__(self) - self._data = numpy.zeros((0, 0), dtype=numpy.complex64) self._dataByModesCache = {} self._amplitudeRangeInfo = None, 2 @@ -264,17 +263,9 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn): 'Image is not complex, converting it to complex to plot it.') data = numpy.array(data, dtype=numpy.complex64) - self._data = data self._dataByModesCache = {} self._setColormappedData(self.getData(copy=False), copy=False) - - # TODO hackish data range implementation - if self.isVisible(): - plot = self.getPlot() - if plot is not None: - plot._invalidateDataRange() - - self._updated(ItemChangedType.DATA) + super().setData(data) def getComplexData(self, copy=True): """Returns the image complex data @@ -283,7 +274,7 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn): False to use internal representation (do not modify!) :rtype: numpy.ndarray of complex """ - return numpy.array(self._data, copy=copy) + return super().getData(copy=copy) def getData(self, copy=True, mode=None): """Returns the image data corresponding to (current) mode. diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py index 9426a13..edc6d89 100644 --- a/silx/gui/plot/items/core.py +++ b/silx/gui/plot/items/core.py @@ -37,6 +37,7 @@ except ImportError: # Python2 support from copy import deepcopy import logging import enum +from typing import Optional, Tuple import warnings import weakref @@ -44,7 +45,9 @@ import numpy import six from ....utils.deprecation import deprecated +from ....utils.proxy import docstring from ....utils.enum import Enum as _Enum +from ....math.combo import min_max from ... import qt from ... import colors from ...colors import Colormap @@ -164,6 +167,13 @@ class Item(qt.QObject): See :class:`ItemChangedType` for flags description. """ + _sigVisibleBoundsChanged = qt.Signal() + """Signal emitted when the visible extent of the item in the plot has changed. + + This signal is emitted only if visible extent tracking is enabled + (see :meth:`_setVisibleBoundsTracking`). + """ + def __init__(self): qt.QObject.__init__(self) self._dirty = True @@ -176,6 +186,9 @@ class Item(qt.QObject): self._ylabel = None self.__name = '' + self.__visibleBoundsTracking = False + self.__previousVisibleBounds = None + self._backendRenderer = None def getPlot(self): @@ -194,7 +207,9 @@ class Item(qt.QObject): """ if plot is not None and self._plotRef is not None: raise RuntimeError('Trying to add a node at two places.') + self.__disconnectFromPlotWidget() self._plotRef = None if plot is None else weakref.ref(plot) + self.__connectToPlotWidget() self._updated() def getBounds(self): # TODO return a Bounds object rather than a tuple @@ -300,6 +315,97 @@ class Item(qt.QObject): info = deepcopy(info) self._info = info + def getVisibleBounds(self) -> Optional[Tuple[float,float,float,float]]: + """Returns visible bounds of the item bounding box in the plot area. + + :returns: + (xmin, xmax, ymin, ymax) in data coordinates of the visible area or + None if item is not visible in the plot area. + :rtype: Union[List[float],None] + """ + plot = self.getPlot() + bounds = self.getBounds() + if plot is None or bounds is None or not self.isVisible(): + return None + + xmin, xmax = numpy.clip(bounds[:2], *plot.getXAxis().getLimits()) + ymin, ymax = numpy.clip( + bounds[2:], *plot.getYAxis(self.__getYAxis()).getLimits()) + + if xmin == xmax or ymin == ymax: # Outside the plot area + return None + else: + return xmin, xmax, ymin, ymax + + def _isVisibleBoundsTracking(self) -> bool: + """Returns True if visible bounds changes are tracked. + + When enabled, :attr:`_sigVisibleBoundsChanged` is emitted upon changes. + :rtype: bool + """ + return self.__visibleBoundsTracking + + def _setVisibleBoundsTracking(self, enable: bool) -> None: + """Set whether or not to track visible bounds changes. + + :param bool enable: + """ + if enable != self.__visibleBoundsTracking: + self.__disconnectFromPlotWidget() + self.__previousVisibleBounds = None + self.__visibleBoundsTracking = enable + self.__connectToPlotWidget() + + def __getYAxis(self) -> str: + """Returns current Y axis ('left' or 'right')""" + return self.getYAxis() if isinstance(self, YAxisMixIn) else 'left' + + def __connectToPlotWidget(self) -> None: + """Connect to PlotWidget signals and install event filter""" + if not self._isVisibleBoundsTracking(): + return + + plot = self.getPlot() + if plot is not None: + for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())): + axis.sigLimitsChanged.connect(self._visibleBoundsChanged) + + plot.installEventFilter(self) + + self._visibleBoundsChanged() + + def __disconnectFromPlotWidget(self) -> None: + """Disconnect from PlotWidget signals and remove event filter""" + if not self._isVisibleBoundsTracking(): + return + + plot = self.getPlot() + if plot is not None: + for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())): + axis.sigLimitsChanged.disconnect(self._visibleBoundsChanged) + + plot.removeEventFilter(self) + + def _visibleBoundsChanged(self, *args) -> None: + """Check if visible extent actually changed and emit signal""" + if not self._isVisibleBoundsTracking(): + return # No visible extent tracking + + plot = self.getPlot() + if plot is None or not plot.isVisible(): + return # No plot or plot not visible + + extent = self.getVisibleBounds() + if extent != self.__previousVisibleBounds: + self.__previousVisibleBounds = extent + self._sigVisibleBoundsChanged.emit() + + def eventFilter(self, watched, event): + """Event filter to handle PlotWidget show events""" + if watched is self.getPlot() and event.type() == qt.QEvent.Show: + self._visibleBoundsChanged() + return super().eventFilter(watched, event) + def _updated(self, event=None, checkVisibility=True): """Mark the item as dirty (i.e., needing update). @@ -375,6 +481,29 @@ class Item(qt.QObject): return PickingResult(self, indices) +class DataItem(Item): + """Item with a data extent in the plot""" + + def _boundsChanged(self, checkVisibility: bool=True) -> None: + """Call this method in subclass when data bounds has changed. + + :param bool checkVisibility: + """ + if not checkVisibility or self.isVisible(): + self._visibleBoundsChanged() + + # TODO hackish data range implementation + plot = self.getPlot() + if plot is not None: + plot._invalidateDataRange() + + @docstring(Item) + def setVisible(self, visible: bool): + if visible != self.isVisible(): + self._boundsChanged(checkVisibility=False) + super().setVisible(visible) + + # Mix-in classes ############################################################## class ItemMixInBase(object): @@ -836,6 +965,22 @@ class YAxisMixIn(ItemMixInBase): assert yaxis in ('left', 'right') if yaxis != self._yaxis: self._yaxis = yaxis + # Handle data extent changed for DataItem + if isinstance(self, DataItem): + self._boundsChanged() + + # Handle visible extent changed + if self._isVisibleBoundsTracking(): + # Switch Y axis signal connection + plot = self.getPlot() + if plot is not None: + previousYAxis = 'left' if self.getXAxis() == 'right' else 'right' + plot.getYAxis(previousYAxis).sigLimitsChanged.disconnect( + self._visibleBoundsChanged) + plot.getYAxis(self.getYAxis()).sigLimitsChanged.connect( + self._visibleBoundsChanged) + self._visibleBoundsChanged() + self._updated(ItemChangedType.YAXIS) @@ -1066,6 +1211,16 @@ class ScatterVisualizationMixIn(ItemMixInBase): Available reduction functions are: 'mean' (default), 'count', 'sum'. """ + DATA_BOUNDS_HINT = 'data_bounds_hint' + """The expected bounds of the data in data coordinates. + + A 2-tuple of 2-tuple: ((ymin, ymax), (xmin, xmax)). + This provides a hint for the data ranges in both dimensions. + It is eventually enlarged with actually data ranges. + + WARNING: dimension 0 i.e., Y first. + """ + _SUPPORTED_VISUALIZATION_PARAMETER_VALUES = { VisualizationParameter.GRID_MAJOR_ORDER: ('row', 'column'), VisualizationParameter.BINNED_STATISTIC_FUNCTION: ('mean', 'count', 'sum'), @@ -1191,7 +1346,7 @@ class ScatterVisualizationMixIn(ItemMixInBase): return self.getVisualizationParameter(parameter) -class PointsBase(Item, SymbolMixIn, AlphaMixIn): +class PointsBase(DataItem, SymbolMixIn, AlphaMixIn): """Base class for :class:`Curve` and :class:`Scatter`""" # note: _logFilterData must be overloaded if you overload # getData to change its signature @@ -1201,7 +1356,7 @@ class PointsBase(Item, SymbolMixIn, AlphaMixIn): on top of images.""" def __init__(self): - Item.__init__(self) + DataItem.__init__(self) SymbolMixIn.__init__(self) AlphaMixIn.__init__(self) self._x = () @@ -1244,18 +1399,18 @@ class PointsBase(Item, SymbolMixIn, AlphaMixIn): # expand errorbars to 2xN if error.size == 1: # Scalar error = numpy.full( - (2, len(value)), error, dtype=numpy.float) + (2, len(value)), error, dtype=numpy.float64) elif error.ndim == 1: # N array newError = numpy.empty((2, len(value)), - dtype=numpy.float) + dtype=numpy.float64) newError[0, :] = error newError[1, :] = error error = newError elif error.size == 2 * len(value): # 2xN array error = numpy.array( - error, copy=True, dtype=numpy.float) + error, copy=True, dtype=numpy.float64) else: _logger.error("Unhandled error array") @@ -1309,9 +1464,9 @@ class PointsBase(Item, SymbolMixIn, AlphaMixIn): if numpy.any(clipped): # copy to keep original array and convert to float - x = numpy.array(x, copy=True, dtype=numpy.float) + x = numpy.array(x, copy=True, dtype=numpy.float64) x[clipped] = numpy.nan - y = numpy.array(y, copy=True, dtype=numpy.float) + y = numpy.array(y, copy=True, dtype=numpy.float64) y[clipped] = numpy.nan if xPositive and xerror is not None: @@ -1347,15 +1502,11 @@ class PointsBase(Item, SymbolMixIn, AlphaMixIn): else: x, y, _xerror, _yerror = data - with warnings.catch_warnings(): - warnings.simplefilter('ignore', category=RuntimeWarning) - # Ignore All-NaN slice encountered - self._boundsCache[(xPositive, yPositive)] = ( - numpy.nanmin(x), - numpy.nanmax(x), - numpy.nanmin(y), - numpy.nanmax(y) - ) + xmin, xmax = min_max(x, finite=True) + ymin, ymax = min_max(y, finite=True) + self._boundsCache[(xPositive, yPositive)] = tuple([ + (bound if bound is not None else numpy.nan) + for bound in (xmin, xmax, ymin, ymax)]) return self._boundsCache[(xPositive, yPositive)] def _getCachedData(self): @@ -1477,11 +1628,7 @@ class PointsBase(Item, SymbolMixIn, AlphaMixIn): self._filteredCache = {} # Reset cached filtered data self._clippedCache = {} # Reset cached clipped bool array - # TODO hackish data range implementation - if self.isVisible(): - plot = self.getPlot() - if plot is not None: - plot._invalidateDataRange() + self._boundsChanged() self._updated(ItemChangedType.DATA) diff --git a/silx/gui/plot/items/curve.py b/silx/gui/plot/items/curve.py index 7922fa1..75e7f01 100644 --- a/silx/gui/plot/items/curve.py +++ b/silx/gui/plot/items/curve.py @@ -185,15 +185,6 @@ class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn, self._setBaseline(Curve._DEFAULT_BASELINE) - self.sigItemChanged.connect(self.__itemChanged) - - def __itemChanged(self, event): - if event == ItemChangedType.YAXIS: - # TODO hackish data range implementation - plot = self.getPlot() - if plot is not None: - plot._invalidateDataRange() - def _addBackendRenderer(self, backend): """Update backend renderer""" # Filter-out values <= 0 @@ -251,20 +242,6 @@ class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn, else: raise IndexError("Index out of range: %s", str(item)) - def setVisible(self, visible): - """Set visibility of item. - - :param bool visible: True to display it, False otherwise - """ - visible = bool(visible) - # TODO hackish data range implementation - if self.isVisible() != visible: - plot = self.getPlot() - if plot is not None: - plot._invalidateDataRange() - - super(Curve, self).setVisible(visible) - @deprecated(replacement='Curve.getHighlightedStyle().getColor()', since_version='0.9.0') def getHighlightedColor(self): diff --git a/silx/gui/plot/items/histogram.py b/silx/gui/plot/items/histogram.py index 935f8d5..5941cc6 100644 --- a/silx/gui/plot/items/histogram.py +++ b/silx/gui/plot/items/histogram.py @@ -38,7 +38,7 @@ try: except ImportError: # Python2 support import collections as abc -from .core import (Item, AlphaMixIn, BaselineMixIn, ColorMixIn, FillMixIn, +from .core import (DataItem, AlphaMixIn, BaselineMixIn, ColorMixIn, FillMixIn, LineMixIn, YAxisMixIn, ItemChangedType) _logger = logging.getLogger(__name__) @@ -100,7 +100,7 @@ def _getHistogramCurve(histogram, edges): # TODO: Yerror, test log scale -class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn, +class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn, LineMixIn, YAxisMixIn, BaselineMixIn): """Description of an histogram""" @@ -119,7 +119,7 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn, _DEFAULT_BASELINE = None def __init__(self): - Item.__init__(self) + DataItem.__init__(self) AlphaMixIn.__init__(self) BaselineMixIn.__init__(self) ColorMixIn.__init__(self) @@ -157,8 +157,8 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn, (x <= 0) if xPositive else False, (y <= 0) if yPositive else False) # Make a copy and replace negative points by NaN - x = numpy.array(x, dtype=numpy.float) - y = numpy.array(y, dtype=numpy.float) + x = numpy.array(x, dtype=numpy.float64) + y = numpy.array(y, dtype=numpy.float64) x[clipped] = numpy.nan y[clipped] = numpy.nan @@ -187,17 +187,17 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn, yPositive = False if xPositive or yPositive: - values = numpy.array(values, copy=True, dtype=numpy.float) + values = numpy.array(values, copy=True, dtype=numpy.float64) if xPositive: # Replace edges <= 0 by NaN and corresponding values by NaN clipped_edges = (edges <= 0) - edges = numpy.array(edges, copy=True, dtype=numpy.float) + edges = numpy.array(edges, copy=True, dtype=numpy.float64) edges[clipped_edges] = numpy.nan clipped_values = numpy.logical_or(clipped_edges[:-1], clipped_edges[1:]) else: - clipped_values = numpy.zeros_like(values, dtype=numpy.bool) + clipped_values = numpy.zeros_like(values, dtype=bool) if yPositive: # Replace values <= 0 by NaN, do not modify edges @@ -219,19 +219,6 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn, min(0, numpy.nanmin(values)), max(0, numpy.nanmax(values))) - def setVisible(self, visible): - """Set visibility of item. - - :param bool visible: True to display it, False otherwise - """ - visible = bool(visible) - # TODO hackish data range implementation - if self.isVisible() != visible: - plot = self.getPlot() - if plot is not None: - plot._invalidateDataRange() - super(Histogram, self).setVisible(visible) - def getValueData(self, copy=True): """The values of the histogram @@ -314,11 +301,7 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn, self._alignement = align self._setBaseline(baseline) - if self.isVisible(): - plot = self.getPlot() - if plot is not None: - plot._invalidateDataRange() - + self._boundsChanged() self._updated(ItemChangedType.DATA) def getAlignment(self): diff --git a/silx/gui/plot/items/image.py b/silx/gui/plot/items/image.py index 91c051d..fda4245 100644 --- a/silx/gui/plot/items/image.py +++ b/silx/gui/plot/items/image.py @@ -40,7 +40,7 @@ import logging import numpy from ....utils.proxy import docstring -from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, +from .core import (DataItem, LabelsMixIn, DraggableMixIn, ColormapMixIn, AlphaMixIn, ItemChangedType) @@ -87,15 +87,20 @@ def _convertImageToRgba32(image, copy=True): return numpy.array(image, copy=copy) -class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn): - """Description of an image""" +class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn): + """Description of an image - def __init__(self): - Item.__init__(self) + :param numpy.ndarray data: Initial image data + """ + + def __init__(self, data=None): + DataItem.__init__(self) LabelsMixIn.__init__(self) DraggableMixIn.__init__(self) AlphaMixIn.__init__(self) - self._data = numpy.zeros((0, 0, 4), dtype=numpy.uint8) + if data is None: + data = numpy.zeros((0, 0, 4), dtype=numpy.uint8) + self._data = data self._origin = (0., 0.) self._scale = (1., 1.) @@ -129,19 +134,6 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn): else: raise IndexError("Index out of range: %s" % str(item)) - def setVisible(self, visible): - """Set visibility of item. - - :param bool visible: True to display it, False otherwise - """ - visible = bool(visible) - # TODO hackish data range implementation - if self.isVisible() != visible: - plot = self.getPlot() - if plot is not None: - plot._invalidateDataRange() - super(ImageBase, self).setVisible(visible) - def _isPlotLinear(self, plot): """Return True if plot only uses linear scale for both of x and y axes.""" @@ -189,6 +181,15 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn): """ return numpy.array(self._data, copy=copy) + def setData(self, data): + """Set the image data + + :param numpy.ndarray data: + """ + self._data = data + self._boundsChanged() + self._updated(ItemChangedType.DATA) + def getRgbaImageData(self, copy=True): """Get the displayed RGB(A) image @@ -215,13 +216,7 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn): origin = float(origin), float(origin) if origin != self._origin: self._origin = origin - - # TODO hackish data range implementation - if self.isVisible(): - plot = self.getPlot() - if plot is not None: - plot._invalidateDataRange() - + self._boundsChanged() self._updated(ItemChangedType.POSITION) def getScale(self): @@ -244,13 +239,7 @@ class ImageBase(Item, LabelsMixIn, DraggableMixIn, AlphaMixIn): if scale != self._scale: self._scale = scale - - # TODO hackish data range implementation - if self.isVisible(): - plot = self.getPlot() - if plot is not None: - plot._invalidateDataRange() - + self._boundsChanged() self._updated(ItemChangedType.SCALE) @@ -258,9 +247,8 @@ class ImageData(ImageBase, ColormapMixIn): """Description of a data image with a colormap""" def __init__(self): - ImageBase.__init__(self) + ImageBase.__init__(self, numpy.zeros((0, 0), dtype=numpy.float32)) ColormapMixIn.__init__(self) - self._data = numpy.zeros((0, 0), dtype=numpy.float32) self._alternativeImage = None self.__alpha = None @@ -370,7 +358,6 @@ class ImageData(ImageBase, ColormapMixIn): _logger.warning( 'Converting complex image to absolute value to plot it.') data = numpy.absolute(data) - self._data = data self._setColormappedData(data, copy=False) if alternative is not None: @@ -389,20 +376,14 @@ class ImageData(ImageBase, ColormapMixIn): alpha = numpy.clip(alpha, 0., 1.) self.__alpha = alpha - # TODO hackish data range implementation - if self.isVisible(): - plot = self.getPlot() - if plot is not None: - plot._invalidateDataRange() - - self._updated(ItemChangedType.DATA) + super().setData(data) class ImageRgba(ImageBase): """Description of an RGB(A) image""" def __init__(self): - ImageBase.__init__(self) + ImageBase.__init__(self, numpy.zeros((0, 0, 4), dtype=numpy.uint8)) def _addBackendRenderer(self, backend): """Update backend renderer""" @@ -440,15 +421,7 @@ class ImageRgba(ImageBase): data = numpy.array(data, copy=copy) assert data.ndim == 3 assert data.shape[-1] in (3, 4) - self._data = data - - # TODO hackish data range implementation - if self.isVisible(): - plot = self.getPlot() - if plot is not None: - plot._invalidateDataRange() - - self._updated(ItemChangedType.DATA) + super().setData(data) class MaskImageData(ImageData): diff --git a/silx/gui/plot/items/roi.py b/silx/gui/plot/items/roi.py index ff73fe6..38a1424 100644 --- a/silx/gui/plot/items/roi.py +++ b/silx/gui/plot/items/roi.py @@ -36,729 +36,25 @@ __date__ = "28/06/2018" import logging import numpy -import weakref -from silx.image.shapes import Polygon -from ....utils.weakref import WeakList -from ... import qt from ... import utils from .. import items -from ..items import core from ...colors import rgba -import silx.utils.deprecation +from silx.image.shapes import Polygon from silx.image._boundingbox import _BoundingBox from ....utils.proxy import docstring from ..utils.intersections import segments_intersection +from ._roi_base import _RegionOfInterestBase +# He following imports have to be exposed by this module +from ._roi_base import RegionOfInterest +from ._roi_base import HandleBasedROI +from ._arc_roi import ArcROI # noqa +from ._roi_base import InteractionModeMixIn # noqa +from ._roi_base import RoiInteractionMode # noqa -logger = logging.getLogger(__name__) - - -class _RegionOfInterestBase(qt.QObject): - """Base class of 1D and 2D region of interest - - :param QObject parent: See QObject - :param str name: The name of the ROI - """ - - sigAboutToBeRemoved = qt.Signal() - """Signal emitted just before this ROI is removed from its manager.""" - - sigItemChanged = qt.Signal(object) - """Signal emitted when item has changed. - - It provides a flag describing which property of the item has changed. - See :class:`ItemChangedType` for flags description. - """ - - def __init__(self, parent=None): - qt.QObject.__init__(self, parent=parent) - self.__name = '' - - def getName(self): - """Returns the name of the ROI - - :return: name of the region of interest - :rtype: str - """ - return self.__name - - def setName(self, name): - """Set the name of the ROI - - :param str name: name of the region of interest - """ - name = str(name) - if self.__name != name: - self.__name = name - self._updated(items.ItemChangedType.NAME) - - def _updated(self, event=None, checkVisibility=True): - """Implement Item mix-in update method by updating the plot items - - See :class:`~silx.gui.plot.items.Item._updated` - """ - self.sigItemChanged.emit(event) - - def contains(self, position): - """Returns True if the `position` is in this ROI. - - :param tuple[float,float] position: position to check - :return: True if the value / point is consider to be in the region of - interest. - :rtype: bool - """ - raise NotImplementedError("Base class") - - -class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn): - """Object describing a region of interest in a plot. - - :param QObject parent: - The RegionOfInterestManager that created this object - """ - - _DEFAULT_LINEWIDTH = 1. - """Default line width of the curve""" - - _DEFAULT_LINESTYLE = '-' - """Default line style of the curve""" - - _DEFAULT_HIGHLIGHT_STYLE = items.CurveStyle(linewidth=2) - """Default highlight style of the item""" - - ICON, NAME, SHORT_NAME = None, None, None - """Metadata to describe the ROI in labels, tooltips and widgets - - Should be set by inherited classes to custom the ROI manager widget. - """ - - sigRegionChanged = qt.Signal() - """Signal emitted everytime the shape or position of the ROI changes""" - - sigEditingStarted = qt.Signal() - """Signal emitted when the user start editing the roi""" - - sigEditingFinished = qt.Signal() - """Signal emitted when the region edition is finished. During edition - sigEditionChanged will be emitted several times and - sigRegionEditionFinished only at end""" - - def __init__(self, parent=None): - # Avoid circular dependency - from ..tools import roi as roi_tools - assert parent is None or isinstance(parent, roi_tools.RegionOfInterestManager) - _RegionOfInterestBase.__init__(self, parent) - core.HighlightedMixIn.__init__(self) - self._color = rgba('red') - self._editable = False - self._selectable = False - self._focusProxy = None - self._visible = True - self._child = WeakList() - - def _connectToPlot(self, plot): - """Called after connection to a plot""" - for item in self.getItems(): - # This hack is needed to avoid reentrant call from _disconnectFromPlot - # to the ROI manager. It also speed up the item tests in _itemRemoved - item._roiGroup = True - plot.addItem(item) - - def _disconnectFromPlot(self, plot): - """Called before disconnection from a plot""" - for item in self.getItems(): - # The item could be already be removed by the plot - if item.getPlot() is not None: - del item._roiGroup - plot.removeItem(item) - - def _setItemName(self, item): - """Helper to generate a unique id to a plot item""" - legend = "__ROI-%d__%d" % (id(self), id(item)) - item.setName(legend) - - def setParent(self, parent): - """Set the parent of the RegionOfInterest - - :param Union[None,RegionOfInterestManager] parent: The new parent - """ - # Avoid circular dependency - from ..tools import roi as roi_tools - if (parent is not None and not isinstance(parent, roi_tools.RegionOfInterestManager)): - raise ValueError('Unsupported parent') - - previousParent = self.parent() - if previousParent is not None: - previousPlot = previousParent.parent() - if previousPlot is not None: - self._disconnectFromPlot(previousPlot) - super(RegionOfInterest, self).setParent(parent) - if parent is not None: - plot = parent.parent() - if plot is not None: - self._connectToPlot(plot) - - def addItem(self, item): - """Add an item to the set of this ROI children. - - This item will be added and removed to the plot used by the ROI. - - If the ROI is already part of a plot, the item will also be added to - the plot. - - It the item do not have a name already, a unique one is generated to - avoid item collision in the plot. - - :param silx.gui.plot.items.Item item: A plot item - """ - assert item is not None - self._child.append(item) - if item.getName() == '': - self._setItemName(item) - manager = self.parent() - if manager is not None: - plot = manager.parent() - if plot is not None: - item._roiGroup = True - plot.addItem(item) - - def removeItem(self, item): - """Remove an item from this ROI children. - - If the item is part of a plot it will be removed too. - - :param silx.gui.plot.items.Item item: A plot item - """ - assert item is not None - self._child.remove(item) - plot = item.getPlot() - if plot is not None: - del item._roiGroup - plot.removeItem(item) - - def getItems(self): - """Returns the list of PlotWidget items of this RegionOfInterest. - - :rtype: List[~silx.gui.plot.items.Item] - """ - return tuple(self._child) - - @classmethod - def _getShortName(cls): - """Return an human readable kind of ROI - - :rtype: str - """ - if hasattr(cls, "SHORT_NAME"): - name = cls.SHORT_NAME - if name is None: - name = cls.__name__ - return name - - def getColor(self): - """Returns the color of this ROI - - :rtype: QColor - """ - return qt.QColor.fromRgbF(*self._color) - - def setColor(self, color): - """Set the color used for this ROI. - - :param color: The color to use for ROI shape as - either a color name, a QColor, a list of uint8 or float in [0, 1]. - """ - color = rgba(color) - if color != self._color: - self._color = color - self._updated(items.ItemChangedType.COLOR) - - @silx.utils.deprecation.deprecated(reason='API modification', - replacement='getName()', - since_version=0.12) - def getLabel(self): - """Returns the label displayed for this ROI. - - :rtype: str - """ - return self.getName() - - @silx.utils.deprecation.deprecated(reason='API modification', - replacement='setName(name)', - since_version=0.12) - def setLabel(self, label): - """Set the label displayed with this ROI. - - :param str label: The text label to display - """ - self.setName(name=label) - - def isEditable(self): - """Returns whether the ROI is editable by the user or not. - - :rtype: bool - """ - return self._editable - - def setEditable(self, editable): - """Set whether the ROI can be changed interactively. - - :param bool editable: True to allow edition by the user, - False to disable. - """ - editable = bool(editable) - if self._editable != editable: - self._editable = editable - self._updated(items.ItemChangedType.EDITABLE) - - def isSelectable(self): - """Returns whether the ROI is selectable by the user or not. - - :rtype: bool - """ - return self._selectable - - def setSelectable(self, selectable): - """Set whether the ROI can be selected interactively. - - :param bool selectable: True to allow selection by the user, - False to disable. - """ - selectable = bool(selectable) - if self._selectable != selectable: - self._selectable = selectable - self._updated(items.ItemChangedType.SELECTABLE) - - def getFocusProxy(self): - """Returns the ROI which have to be selected when this ROI is selected, - else None if no proxy specified. - - :rtype: RegionOfInterest - """ - proxy = self._focusProxy - if proxy is None: - return None - proxy = proxy() - if proxy is None: - self._focusProxy = None - return proxy - - def setFocusProxy(self, roi): - """Set the real ROI which will be selected when this ROI is selected, - else None to remove the proxy already specified. - - :param RegionOfInterest roi: A ROI - """ - if roi is not None: - self._focusProxy = weakref.ref(roi) - else: - self._focusProxy = None - - def isVisible(self): - """Returns whether the ROI is visible in the plot. - - .. note:: - This does not take into account whether or not the plot - widget itself is visible (unlike :meth:`QWidget.isVisible` which - checks the visibility of all its parent widgets up to the window) - - :rtype: bool - """ - return self._visible - - def setVisible(self, visible): - """Set whether the plot items associated with this ROI are - visible in the plot. - - :param bool visible: True to show the ROI in the plot, False to - hide it. - """ - visible = bool(visible) - if self._visible != visible: - self._visible = visible - self._updated(items.ItemChangedType.VISIBLE) - - @classmethod - def showFirstInteractionShape(cls): - """Returns True if the shape created by the first interaction and - managed by the plot have to be visible. - - :rtype: bool - """ - return False - - @classmethod - def getFirstInteractionShape(cls): - """Returns the shape kind which will be used by the very first - interaction with the plot. - - This interactions are hardcoded inside the plot - - :rtype: str - """ - return cls._plotShape - - def setFirstShapePoints(self, points): - """"Initialize the ROI using the points from the first interaction. - - This interaction is constrained by the plot API and only supports few - shapes. - """ - raise NotImplementedError() - - def creationStarted(self): - """"Called when the ROI creation interaction was started. - """ - pass - - @docstring(_RegionOfInterestBase) - def contains(self, position): - raise NotImplementedError("Base class") - - def creationFinalized(self): - """"Called when the ROI creation interaction was finalized. - """ - pass - - def _updateItemProperty(self, event, source, destination): - """Update the item property of a destination from an item source. - - :param items.ItemChangedType event: Property type to update - :param silx.gui.plot.items.Item source: The reference for the data - :param event Union[Item,List[Item]] destination: The item(s) to update - """ - if not isinstance(destination, (list, tuple)): - destination = [destination] - if event == items.ItemChangedType.NAME: - value = source.getName() - for d in destination: - d.setName(value) - elif event == items.ItemChangedType.EDITABLE: - value = source.isEditable() - for d in destination: - d.setEditable(value) - elif event == items.ItemChangedType.SELECTABLE: - value = source.isSelectable() - for d in destination: - d._setSelectable(value) - elif event == items.ItemChangedType.COLOR: - value = rgba(source.getColor()) - for d in destination: - d.setColor(value) - elif event == items.ItemChangedType.LINE_STYLE: - value = self.getLineStyle() - for d in destination: - d.setLineStyle(value) - elif event == items.ItemChangedType.LINE_WIDTH: - value = self.getLineWidth() - for d in destination: - d.setLineWidth(value) - elif event == items.ItemChangedType.SYMBOL: - value = self.getSymbol() - for d in destination: - d.setSymbol(value) - elif event == items.ItemChangedType.SYMBOL_SIZE: - value = self.getSymbolSize() - for d in destination: - d.setSymbolSize(value) - elif event == items.ItemChangedType.VISIBLE: - value = self.isVisible() - for d in destination: - d.setVisible(value) - else: - assert False - - def _updated(self, event=None, checkVisibility=True): - if event == items.ItemChangedType.HIGHLIGHTED: - style = self.getCurrentStyle() - self._updatedStyle(event, style) - else: - hilighted = self.isHighlighted() - if hilighted: - if event == items.ItemChangedType.HIGHLIGHTED_STYLE: - style = self.getCurrentStyle() - self._updatedStyle(event, style) - else: - if event in [items.ItemChangedType.COLOR, - items.ItemChangedType.LINE_STYLE, - items.ItemChangedType.LINE_WIDTH, - items.ItemChangedType.SYMBOL, - items.ItemChangedType.SYMBOL_SIZE]: - style = self.getCurrentStyle() - self._updatedStyle(event, style) - super(RegionOfInterest, self)._updated(event, checkVisibility) - - def _updatedStyle(self, event, style): - """Called when the current displayed style of the ROI was changed. - - :param event: The event responsible of the change of the style - :param items.CurveStyle style: The current style - """ - pass - - def getCurrentStyle(self): - """Returns the current curve style. - - Curve style depends on curve highlighting - - :rtype: CurveStyle - """ - baseColor = rgba(self.getColor()) - if isinstance(self, core.LineMixIn): - baseLinestyle = self.getLineStyle() - baseLinewidth = self.getLineWidth() - else: - baseLinestyle = self._DEFAULT_LINESTYLE - baseLinewidth = self._DEFAULT_LINEWIDTH - if isinstance(self, core.SymbolMixIn): - baseSymbol = self.getSymbol() - baseSymbolsize = self.getSymbolSize() - else: - baseSymbol = 'o' - baseSymbolsize = 1 - - if self.isHighlighted(): - style = self.getHighlightedStyle() - color = style.getColor() - linestyle = style.getLineStyle() - linewidth = style.getLineWidth() - symbol = style.getSymbol() - symbolsize = style.getSymbolSize() - - return items.CurveStyle( - color=baseColor if color is None else color, - linestyle=baseLinestyle if linestyle is None else linestyle, - linewidth=baseLinewidth if linewidth is None else linewidth, - symbol=baseSymbol if symbol is None else symbol, - symbolsize=baseSymbolsize if symbolsize is None else symbolsize) - else: - return items.CurveStyle(color=baseColor, - linestyle=baseLinestyle, - linewidth=baseLinewidth, - symbol=baseSymbol, - symbolsize=baseSymbolsize) - - def _editingStarted(self): - assert self._editable is True - self.sigEditingStarted.emit() - - def _editingFinished(self): - self.sigEditingFinished.emit() - - -class HandleBasedROI(RegionOfInterest): - """Manage a ROI based on a set of handles""" - - def __init__(self, parent=None): - RegionOfInterest.__init__(self, parent=parent) - self._handles = [] - self._posOrigin = None - self._posPrevious = None - - def addUserHandle(self, item=None): - """ - Add a new free handle to the ROI. - - This handle do nothing. It have to be managed by the ROI - implementing this class. - - :param Union[None,silx.gui.plot.items.Marker] item: The new marker to - add, else None to create a default marker. - :rtype: silx.gui.plot.items.Marker - """ - return self.addHandle(item, role="user") - - def addLabelHandle(self, item=None): - """ - Add a new label handle to the ROI. - - This handle is not draggable nor selectable. - - It is displayed without symbol, but it is always visible anyway - the ROI is editable, in order to display text. - - :param Union[None,silx.gui.plot.items.Marker] item: The new marker to - add, else None to create a default marker. - :rtype: silx.gui.plot.items.Marker - """ - return self.addHandle(item, role="label") - - def addTranslateHandle(self, item=None): - """ - Add a new translate handle to the ROI. - - Dragging translate handles affect the position position of the ROI - but not the shape itself. - - :param Union[None,silx.gui.plot.items.Marker] item: The new marker to - add, else None to create a default marker. - :rtype: silx.gui.plot.items.Marker - """ - return self.addHandle(item, role="translate") - - def addHandle(self, item=None, role="default"): - """ - Add a new handle to the ROI. - - Dragging handles while affect the position or the shape of the - ROI. - - :param Union[None,silx.gui.plot.items.Marker] item: The new marker to - add, else None to create a default marker. - :rtype: silx.gui.plot.items.Marker - """ - if item is None: - item = items.Marker() - color = rgba(self.getColor()) - color = self._computeHandleColor(color) - item.setColor(color) - if role == "default": - item.setSymbol("s") - elif role == "user": - pass - elif role == "translate": - item.setSymbol("+") - elif role == "label": - item.setSymbol("") - - if role == "user": - pass - elif role == "label": - item._setSelectable(False) - item._setDraggable(False) - item.setVisible(True) - else: - self.__updateEditable(item, self.isEditable(), remove=False) - item._setSelectable(False) - - self._handles.append((item, role)) - self.addItem(item) - return item - - def removeHandle(self, handle): - data = [d for d in self._handles if d[0] is handle][0] - self._handles.remove(data) - role = data[1] - if role not in ["user", "label"]: - if self.isEditable(): - self.__updateEditable(handle, False) - self.removeItem(handle) - - def getHandles(self): - """Returns the list of handles of this HandleBasedROI. - - :rtype: List[~silx.gui.plot.items.Marker] - """ - return tuple(data[0] for data in self._handles) - - def _updated(self, event=None, checkVisibility=True): - """Implement Item mix-in update method by updating the plot items - - See :class:`~silx.gui.plot.items.Item._updated` - """ - if event == items.ItemChangedType.NAME: - self._updateText(self.getName()) - elif event == items.ItemChangedType.VISIBLE: - for item, role in self._handles: - visible = self.isVisible() - editionVisible = visible and self.isEditable() - if role not in ["user", "label"]: - item.setVisible(editionVisible) - else: - item.setVisible(visible) - elif event == items.ItemChangedType.EDITABLE: - for item, role in self._handles: - editable = self.isEditable() - if role not in ["user", "label"]: - self.__updateEditable(item, editable) - super(HandleBasedROI, self)._updated(event, checkVisibility) - - def _updatedStyle(self, event, style): - super(HandleBasedROI, self)._updatedStyle(event, style) - - # Update color of shape items in the plot - color = rgba(self.getColor()) - handleColor = self._computeHandleColor(color) - for item, role in self._handles: - if role == 'user': - pass - elif role == 'label': - item.setColor(color) - else: - item.setColor(handleColor) - - def __updateEditable(self, handle, editable, remove=True): - # NOTE: visibility change emit a position update event - handle.setVisible(editable and self.isVisible()) - handle._setDraggable(editable) - if editable: - handle.sigDragStarted.connect(self._handleEditingStarted) - handle.sigItemChanged.connect(self._handleEditingUpdated) - handle.sigDragFinished.connect(self._handleEditingFinished) - else: - if remove: - handle.sigDragStarted.disconnect(self._handleEditingStarted) - handle.sigItemChanged.disconnect(self._handleEditingUpdated) - handle.sigDragFinished.disconnect(self._handleEditingFinished) - - def _handleEditingStarted(self): - super(HandleBasedROI, self)._editingStarted() - handle = self.sender() - self._posOrigin = numpy.array(handle.getPosition()) - self._posPrevious = numpy.array(self._posOrigin) - self.handleDragStarted(handle, self._posOrigin) - - def _handleEditingUpdated(self): - if self._posOrigin is None: - # Avoid to handle events when visibility change - return - handle = self.sender() - current = numpy.array(handle.getPosition()) - self.handleDragUpdated(handle, self._posOrigin, self._posPrevious, current) - self._posPrevious = current - - def _handleEditingFinished(self): - handle = self.sender() - current = numpy.array(handle.getPosition()) - self.handleDragFinished(handle, self._posOrigin, current) - self._posPrevious = None - self._posOrigin = None - super(HandleBasedROI, self)._editingFinished() - - def isHandleBeingDragged(self): - """Returns True if one of the handles is currently being dragged. - - :rtype: bool - """ - return self._posOrigin is not None - - def handleDragStarted(self, handle, origin): - """Called when an handler drag started""" - pass - - def handleDragUpdated(self, handle, origin, previous, current): - """Called when an handle drag position changed""" - pass - - def handleDragFinished(self, handle, origin, current): - """Called when an handle drag finished""" - pass - - def _computeHandleColor(self, color): - """Returns the anchor color from the base ROI color - :param Union[numpy.array,Tuple,List]: color - :rtype: Union[numpy.array,Tuple,List] - """ - return color[:3] + (0.5,) - - def _updateText(self, text): - """Update the text displayed by this ROI - - :param str text: A text - """ - pass +logger = logging.getLogger(__name__) class PointROI(RegionOfInterest, items.SymbolMixIn): @@ -821,7 +117,8 @@ class PointROI(RegionOfInterest, items.SymbolMixIn): @docstring(_RegionOfInterestBase) def contains(self, position): - raise NotImplementedError('Base class') + roiPos = self.getPosition() + return position[0] == roiPos[0] and position[1] == roiPos[1] def _pointPositionChanged(self, event): """Handle position changed events of the marker""" @@ -1022,11 +319,12 @@ class LineROI(HandleBasedROI, items.LineMixIn): top_left = position[0], position[1] + 1 top_right = position[0] + 1, position[1] + 1 - line_pt1 = self._points[0] - line_pt2 = self._points[1] + points = self.__shape.getPoints() + line_pt1 = points[0] + line_pt2 = points[1] - bb1 = _BoundingBox.from_points(self._points) - if bb1.contains(position) is False: + bb1 = _BoundingBox.from_points(points) + if not bb1.contains(position): return False return ( @@ -1038,7 +336,7 @@ class LineROI(HandleBasedROI, items.LineMixIn): seg2_start_pt=top_right, seg2_end_pt=top_left) or segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2, seg2_start_pt=top_left, seg2_end_pt=bottom_left) - ) + ) is not None def __str__(self): start, end = self.getEndPoints() @@ -1106,7 +404,7 @@ class HorizontalLineROI(RegionOfInterest, items.LineMixIn): @docstring(_RegionOfInterestBase) def contains(self, position): - return position[1] == self.getPosition()[1] + return position[1] == self.getPosition() def _linePositionChanged(self, event): """Handle position changed events of the marker""" @@ -1175,7 +473,7 @@ class VerticalLineROI(RegionOfInterest, items.LineMixIn): @docstring(RegionOfInterest) def contains(self, position): - return position[0] == self.getPosition()[0] + return position[0] == self.getPosition() def _linePositionChanged(self, event): """Handle position changed events of the marker""" @@ -1515,6 +813,10 @@ class CircleROI(HandleBasedROI, items.LineMixIn): center = self.getCenter() self.setRadius(numpy.linalg.norm(center - current)) + @docstring(HandleBasedROI) + def contains(self, position): + return numpy.linalg.norm(self.getCenter() - position) <= self.getRadius() + def __str__(self): center = self.getCenter() radius = self.getRadius() @@ -1726,7 +1028,7 @@ class EllipseROI(HandleBasedROI, items.LineMixIn): orientation = self.getOrientation() if self._radius[1] > self._radius[0]: # _handleAxis1 is the major axis - orientation -= numpy.pi/2 + orientation -= numpy.pi / 2 point0 = numpy.array([center[0] + self._radius[0] * numpy.cos(orientation), center[1] + self._radius[0] * numpy.sin(orientation)]) @@ -1760,13 +1062,13 @@ class EllipseROI(HandleBasedROI, items.LineMixIn): if handle is self._handleAxis1: if self._radius[0] > distance: # _handleAxis1 is not the major axis, rotate -90 degrees - orientation -= numpy.pi/2 + orientation -= numpy.pi / 2 radius = self._radius[0], distance else: # _handleAxis0 if self._radius[1] > distance: # _handleAxis0 is not the major axis, rotate +90 degrees - orientation += numpy.pi/2 + orientation += numpy.pi / 2 radius = distance, self._radius[1] self.setGeometry(radius=radius, orientation=orientation) @@ -1776,6 +1078,14 @@ class EllipseROI(HandleBasedROI, items.LineMixIn): if event is items.ItemChangedType.POSITION: self._updateGeometry() + @docstring(HandleBasedROI) + def contains(self, position): + major, minor = self.getMajorRadius(), self.getMinorRadius() + delta = self.getOrientation() + x, y = position - self.getCenter() + return ((x*numpy.cos(delta) + y*numpy.sin(delta))**2/major**2 + + (x*numpy.sin(delta) - y*numpy.cos(delta))**2/minor**2) <= 1 + def __str__(self): center = self.getCenter() major = self.getMajorRadius() @@ -1987,682 +1297,6 @@ class PolygonROI(HandleBasedROI, items.LineMixIn): self._polygon_shape = None -class ArcROI(HandleBasedROI, items.LineMixIn): - """A ROI identifying an arc of a circle with a width. - - This ROI provides - - 3 handle to control the curvature - - 1 handle to control the weight - - 1 anchor to translate the shape. - """ - - ICON = 'add-shape-arc' - NAME = 'arc ROI' - SHORT_NAME = "arc" - """Metadata for this kind of ROI""" - - _plotShape = "line" - """Plot shape which is used for the first interaction""" - - class _Geometry: - def __init__(self): - self.center = None - self.startPoint = None - self.endPoint = None - self.radius = None - self.weight = None - self.startAngle = None - self.endAngle = None - self._closed = None - - @classmethod - def createEmpty(cls): - zero = numpy.array([0, 0]) - return cls.create(zero, zero.copy(), zero.copy(), 0, 0, 0, 0) - - @classmethod - def createRect(cls, startPoint, endPoint, weight): - return cls.create(None, startPoint, endPoint, None, weight, None, None, False) - - @classmethod - def createCircle(cls, center, startPoint, endPoint, radius, - weight, startAngle, endAngle): - return cls.create(center, startPoint, endPoint, radius, - weight, startAngle, endAngle, True) - - @classmethod - def create(cls, center, startPoint, endPoint, radius, - weight, startAngle, endAngle, closed=False): - g = cls() - g.center = center - g.startPoint = startPoint - g.endPoint = endPoint - g.radius = radius - g.weight = weight - g.startAngle = startAngle - g.endAngle = endAngle - g._closed = closed - return g - - def withWeight(self, weight): - """Create a new geometry with another weight - """ - return self.create(self.center, self.startPoint, self.endPoint, - self.radius, weight, - self.startAngle, self.endAngle, self._closed) - - def withRadius(self, radius): - """Create a new geometry with another radius. - - The weight and the center is conserved. - """ - startPoint = self.center + (self.startPoint - self.center) / self.radius * radius - endPoint = self.center + (self.endPoint - self.center) / self.radius * radius - return self.create(self.center, startPoint, endPoint, - radius, self.weight, - self.startAngle, self.endAngle, self._closed) - - def translated(self, x, y): - delta = numpy.array([x, y]) - center = None if self.center is None else self.center + delta - startPoint = None if self.startPoint is None else self.startPoint + delta - endPoint = None if self.endPoint is None else self.endPoint + delta - return self.create(center, startPoint, endPoint, - self.radius, self.weight, - self.startAngle, self.endAngle, self._closed) - - def getKind(self): - """Returns the kind of shape defined""" - if self.center is None: - return "rect" - elif numpy.isnan(self.startAngle): - return "point" - elif self.isClosed(): - if self.weight <= 0 or self.weight * 0.5 >= self.radius: - return "circle" - else: - return "donut" - else: - if self.weight * 0.5 < self.radius: - return "arc" - else: - return "camembert" - - def isClosed(self): - """Returns True if the geometry is a circle like""" - if self._closed is not None: - return self._closed - delta = numpy.abs(self.endAngle - self.startAngle) - self._closed = numpy.isclose(delta, numpy.pi * 2) - return self._closed - - def __str__(self): - return str((self.center, - self.startPoint, - self.endPoint, - self.radius, - self.weight, - self.startAngle, - self.endAngle, - self._closed)) - - def __init__(self, parent=None): - HandleBasedROI.__init__(self, parent=parent) - items.LineMixIn.__init__(self) - self._geometry = self._Geometry.createEmpty() - self._handleLabel = self.addLabelHandle() - - self._handleStart = self.addHandle() - self._handleStart.setSymbol("o") - self._handleMid = self.addHandle() - self._handleMid.setSymbol("o") - self._handleEnd = self.addHandle() - self._handleEnd.setSymbol("o") - self._handleWeight = self.addHandle() - self._handleWeight._setConstraint(self._arcCurvatureMarkerConstraint) - self._handleMove = self.addTranslateHandle() - - shape = items.Shape("polygon") - shape.setPoints([[0, 0], [0, 0]]) - shape.setColor(rgba(self.getColor())) - shape.setFill(False) - shape.setOverlay(True) - shape.setLineStyle(self.getLineStyle()) - shape.setLineWidth(self.getLineWidth()) - self.__shape = shape - self.addItem(shape) - - def _updated(self, event=None, checkVisibility=True): - if event == items.ItemChangedType.VISIBLE: - self._updateItemProperty(event, self, self.__shape) - super(ArcROI, self)._updated(event, checkVisibility) - - def _updatedStyle(self, event, style): - super(ArcROI, self)._updatedStyle(event, style) - self.__shape.setColor(style.getColor()) - self.__shape.setLineStyle(style.getLineStyle()) - self.__shape.setLineWidth(style.getLineWidth()) - - def setFirstShapePoints(self, points): - """"Initialize the ROI using the points from the first interaction. - - This interaction is constrained by the plot API and only supports few - shapes. - """ - # The first shape is a line - point0 = points[0] - point1 = points[1] - - # Compute a non collinear point for the curvature - center = (point1 + point0) * 0.5 - normal = point1 - center - normal = numpy.array((normal[1], -normal[0])) - defaultCurvature = numpy.pi / 5.0 - weightCoef = 0.20 - mid = center - normal * defaultCurvature - distance = numpy.linalg.norm(point0 - point1) - weight = distance * weightCoef - - geometry = self._createGeometryFromControlPoints(point0, mid, point1, weight) - self._geometry = geometry - self._updateHandles() - - def _updateText(self, text): - self._handleLabel.setText(text) - - def _updateMidHandle(self): - """Keep the same geometry, but update the location of the control - points. - - So calling this function do not trigger sigRegionChanged. - """ - geometry = self._geometry - - if geometry.isClosed(): - start = numpy.array(self._handleStart.getPosition()) - geometry.endPoint = start - with utils.blockSignals(self._handleEnd): - self._handleEnd.setPosition(*start) - midPos = geometry.center + geometry.center - start - else: - if geometry.center is None: - midPos = geometry.startPoint * 0.66 + geometry.endPoint * 0.34 - else: - midAngle = geometry.startAngle * 0.66 + geometry.endAngle * 0.34 - vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)]) - midPos = geometry.center + geometry.radius * vector - - with utils.blockSignals(self._handleMid): - self._handleMid.setPosition(*midPos) - - def _updateWeightHandle(self): - geometry = self._geometry - if geometry.center is None: - # rectangle - center = (geometry.startPoint + geometry.endPoint) * 0.5 - normal = geometry.endPoint - geometry.startPoint - normal = numpy.array((normal[1], -normal[0])) - distance = numpy.linalg.norm(normal) - if distance != 0: - normal = normal / distance - weightPos = center + normal * geometry.weight * 0.5 - else: - if geometry.isClosed(): - midAngle = geometry.startAngle + numpy.pi * 0.5 - elif geometry.center is not None: - midAngle = (geometry.startAngle + geometry.endAngle) * 0.5 - vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)]) - weightPos = geometry.center + (geometry.radius + geometry.weight * 0.5) * vector - - with utils.blockSignals(self._handleWeight): - self._handleWeight.setPosition(*weightPos) - - def _getWeightFromHandle(self, weightPos): - geometry = self._geometry - if geometry.center is None: - # rectangle - center = (geometry.startPoint + geometry.endPoint) * 0.5 - return numpy.linalg.norm(center - weightPos) * 2 - else: - distance = numpy.linalg.norm(geometry.center - weightPos) - return abs(distance - geometry.radius) * 2 - - def _updateHandles(self): - geometry = self._geometry - with utils.blockSignals(self._handleStart): - self._handleStart.setPosition(*geometry.startPoint) - with utils.blockSignals(self._handleEnd): - self._handleEnd.setPosition(*geometry.endPoint) - - self._updateMidHandle() - self._updateWeightHandle() - - self._updateShape() - - def _updateCurvature(self, start, mid, end, updateCurveHandles, checkClosed=False): - """Update the curvature using 3 control points in the curve - - :param bool updateCurveHandles: If False curve handles are already at - the right location - """ - if updateCurveHandles: - with utils.blockSignals(self._handleStart): - self._handleStart.setPosition(*start) - with utils.blockSignals(self._handleMid): - self._handleMid.setPosition(*mid) - with utils.blockSignals(self._handleEnd): - self._handleEnd.setPosition(*end) - - if checkClosed: - closed = self._isCloseInPixel(start, end) - else: - closed = self._geometry.isClosed() - - weight = self._geometry.weight - geometry = self._createGeometryFromControlPoints(start, mid, end, weight, closed=closed) - self._geometry = geometry - - self._updateWeightHandle() - self._updateShape() - - def handleDragUpdated(self, handle, origin, previous, current): - if handle is self._handleStart: - mid = numpy.array(self._handleMid.getPosition()) - end = numpy.array(self._handleEnd.getPosition()) - self._updateCurvature(current, mid, end, - checkClosed=True, updateCurveHandles=False) - elif handle is self._handleMid: - if self._geometry.isClosed(): - radius = numpy.linalg.norm(self._geometry.center - current) - self._geometry = self._geometry.withRadius(radius) - self._updateHandles() - else: - start = numpy.array(self._handleStart.getPosition()) - end = numpy.array(self._handleEnd.getPosition()) - self._updateCurvature(start, current, end, updateCurveHandles=False) - elif handle is self._handleEnd: - start = numpy.array(self._handleStart.getPosition()) - mid = numpy.array(self._handleMid.getPosition()) - self._updateCurvature(start, mid, current, - checkClosed=True, updateCurveHandles=False) - elif handle is self._handleWeight: - weight = self._getWeightFromHandle(current) - self._geometry = self._geometry.withWeight(weight) - self._updateShape() - elif handle is self._handleMove: - delta = current - previous - self.translate(*delta) - - def _isCloseInPixel(self, point1, point2): - manager = self.parent() - if manager is None: - return False - plot = manager.parent() - if plot is None: - return False - point1 = plot.dataToPixel(*point1) - if point1 is None: - return False - point2 = plot.dataToPixel(*point2) - if point2 is None: - return False - return abs(point1[0] - point2[0]) + abs(point1[1] - point2[1]) < 15 - - def _normalizeGeometry(self): - """Keep the same phisical geometry, but with normalized parameters. - """ - geometry = self._geometry - if geometry.weight * 0.5 >= geometry.radius: - radius = (geometry.weight * 0.5 + geometry.radius) * 0.5 - geometry = geometry.withRadius(radius) - geometry = geometry.withWeight(radius * 2) - self._geometry = geometry - return True - return False - - def handleDragFinished(self, handle, origin, current): - if handle in [self._handleStart, self._handleMid, self._handleEnd]: - if self._normalizeGeometry(): - self._updateHandles() - else: - self._updateMidHandle() - if self._geometry.isClosed(): - self._handleStart.setSymbol("x") - self._handleEnd.setSymbol("x") - else: - self._handleStart.setSymbol("o") - self._handleEnd.setSymbol("o") - - def _createGeometryFromControlPoints(self, start, mid, end, weight, closed=None): - """Returns the geometry of the object""" - if closed or (closed is None and numpy.allclose(start, end)): - # Special arc: It's a closed circle - center = (start + mid) * 0.5 - radius = numpy.linalg.norm(start - center) - v = start - center - startAngle = numpy.angle(complex(v[0], v[1])) - endAngle = startAngle + numpy.pi * 2.0 - return self._Geometry.createCircle(center, start, end, radius, - weight, startAngle, endAngle) - - elif numpy.linalg.norm(numpy.cross(mid - start, end - start)) < 1e-5: - # Degenerated arc, it's a rectangle - return self._Geometry.createRect(start, end, weight) - else: - center, radius = self._circleEquation(start, mid, end) - v = start - center - startAngle = numpy.angle(complex(v[0], v[1])) - v = mid - center - midAngle = numpy.angle(complex(v[0], v[1])) - v = end - center - endAngle = numpy.angle(complex(v[0], v[1])) - - # Is it clockwise or anticlockwise - relativeMid = (endAngle - midAngle + 2 * numpy.pi) % (2 * numpy.pi) - relativeEnd = (endAngle - startAngle + 2 * numpy.pi) % (2 * numpy.pi) - if relativeMid < relativeEnd: - if endAngle < startAngle: - endAngle += 2 * numpy.pi - else: - if endAngle > startAngle: - endAngle -= 2 * numpy.pi - - return self._Geometry.create(center, start, end, - radius, weight, startAngle, endAngle) - - def _createShapeFromGeometry(self, geometry): - kind = geometry.getKind() - if kind == "rect": - # It is not an arc - # but we can display it as an intermediate shape - normal = (geometry.endPoint - geometry.startPoint) - normal = numpy.array((normal[1], -normal[0])) - distance = numpy.linalg.norm(normal) - if distance != 0: - normal /= distance - points = numpy.array([ - geometry.startPoint + normal * geometry.weight * 0.5, - geometry.endPoint + normal * geometry.weight * 0.5, - geometry.endPoint - normal * geometry.weight * 0.5, - geometry.startPoint - normal * geometry.weight * 0.5]) - elif kind == "point": - # It is not an arc - # but we can display it as an intermediate shape - # NOTE: At least 2 points are expected - points = numpy.array([geometry.startPoint, geometry.startPoint]) - elif kind == "circle": - outerRadius = geometry.radius + geometry.weight * 0.5 - angles = numpy.arange(0, 2 * numpy.pi, 0.1) - # It's a circle - points = [] - numpy.append(angles, angles[-1]) - for angle in angles: - direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) - points.append(geometry.center + direction * outerRadius) - points = numpy.array(points) - elif kind == "donut": - innerRadius = geometry.radius - geometry.weight * 0.5 - outerRadius = geometry.radius + geometry.weight * 0.5 - angles = numpy.arange(0, 2 * numpy.pi, 0.1) - # It's a donut - points = [] - # NOTE: NaN value allow to create 2 separated circle shapes - # using a single plot item. It's a kind of cheat - points.append(numpy.array([float("nan"), float("nan")])) - for angle in angles: - direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) - points.insert(0, geometry.center + direction * innerRadius) - points.append(geometry.center + direction * outerRadius) - points.append(numpy.array([float("nan"), float("nan")])) - points = numpy.array(points) - else: - innerRadius = geometry.radius - geometry.weight * 0.5 - outerRadius = geometry.radius + geometry.weight * 0.5 - - delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1 - if geometry.startAngle == geometry.endAngle: - # Degenerated, it's a line (single radius) - angle = geometry.startAngle - direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) - points = [] - points.append(geometry.center + direction * innerRadius) - points.append(geometry.center + direction * outerRadius) - return numpy.array(points) - - angles = numpy.arange(geometry.startAngle, geometry.endAngle, delta) - if angles[-1] != geometry.endAngle: - angles = numpy.append(angles, geometry.endAngle) - - if kind == "camembert": - # It's a part of camembert - points = [] - points.append(geometry.center) - points.append(geometry.startPoint) - delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1 - for angle in angles: - direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) - points.append(geometry.center + direction * outerRadius) - points.append(geometry.endPoint) - points.append(geometry.center) - elif kind == "arc": - # It's a part of donut - points = [] - points.append(geometry.startPoint) - for angle in angles: - direction = numpy.array([numpy.cos(angle), numpy.sin(angle)]) - points.insert(0, geometry.center + direction * innerRadius) - points.append(geometry.center + direction * outerRadius) - points.insert(0, geometry.endPoint) - points.append(geometry.endPoint) - else: - assert False - - points = numpy.array(points) - - return points - - def _updateShape(self): - geometry = self._geometry - points = self._createShapeFromGeometry(geometry) - self.__shape.setPoints(points) - - index = numpy.nanargmin(points[:, 1]) - pos = points[index] - with utils.blockSignals(self._handleLabel): - self._handleLabel.setPosition(pos[0], pos[1]) - - if geometry.center is None: - movePos = geometry.startPoint * 0.34 + geometry.endPoint * 0.66 - elif (geometry.isClosed() - or abs(geometry.endAngle - geometry.startAngle) > numpy.pi * 0.7): - movePos = geometry.center - else: - moveAngle = geometry.startAngle * 0.34 + geometry.endAngle * 0.66 - vector = numpy.array([numpy.cos(moveAngle), numpy.sin(moveAngle)]) - movePos = geometry.center + geometry.radius * vector - - with utils.blockSignals(self._handleMove): - self._handleMove.setPosition(*movePos) - - self.sigRegionChanged.emit() - - def getGeometry(self): - """Returns a tuple containing the geometry of this ROI - - It is a symmetric function of :meth:`setGeometry`. - - If `startAngle` is smaller than `endAngle` the rotation is clockwise, - else the rotation is anticlockwise. - - :rtype: Tuple[numpy.ndarray,float,float,float,float] - :raise ValueError: In case the ROI can't be represented as section of - a circle - """ - geometry = self._geometry - if geometry.center is None: - raise ValueError("This ROI can't be represented as a section of circle") - return geometry.center, self.getInnerRadius(), self.getOuterRadius(), geometry.startAngle, geometry.endAngle - - def isClosed(self): - """Returns true if the arc is a closed shape, like a circle or a donut. - - :rtype: bool - """ - return self._geometry.isClosed() - - def getCenter(self): - """Returns the center of the circle used to draw arcs of this ROI. - - This center is usually outside the the shape itself. - - :rtype: numpy.ndarray - """ - return self._geometry.center - - def getStartAngle(self): - """Returns the angle of the start of the section of this ROI (in radian). - - If `startAngle` is smaller than `endAngle` the rotation is clockwise, - else the rotation is anticlockwise. - - :rtype: float - """ - return self._geometry.startAngle - - def getEndAngle(self): - """Returns the angle of the end of the section of this ROI (in radian). - - If `startAngle` is smaller than `endAngle` the rotation is clockwise, - else the rotation is anticlockwise. - - :rtype: float - """ - return self._geometry.endAngle - - def getInnerRadius(self): - """Returns the radius of the smaller arc used to draw this ROI. - - :rtype: float - """ - geometry = self._geometry - radius = geometry.radius - geometry.weight * 0.5 - if radius < 0: - radius = 0 - return radius - - def getOuterRadius(self): - """Returns the radius of the bigger arc used to draw this ROI. - - :rtype: float - """ - geometry = self._geometry - radius = geometry.radius + geometry.weight * 0.5 - return radius - - def setGeometry(self, center, innerRadius, outerRadius, startAngle, endAngle): - """ - Set the geometry of this arc. - - :param numpy.ndarray center: Center of the circle. - :param float innerRadius: Radius of the smaller arc of the section. - :param float outerRadius: Weight of the bigger arc of the section. - It have to be bigger than `innerRadius` - :param float startAngle: Location of the start of the section (in radian) - :param float endAngle: Location of the end of the section (in radian). - If `startAngle` is smaller than `endAngle` the rotation is clockwise, - else the rotation is anticlockwise. - """ - assert(innerRadius <= outerRadius) - assert(numpy.abs(startAngle - endAngle) <= 2 * numpy.pi) - center = numpy.array(center) - radius = (innerRadius + outerRadius) * 0.5 - weight = outerRadius - innerRadius - - vector = numpy.array([numpy.cos(startAngle), numpy.sin(startAngle)]) - startPoint = center + vector * radius - vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)]) - endPoint = center + vector * radius - - geometry = self._Geometry.create(center, startPoint, endPoint, - radius, weight, - startAngle, endAngle, closed=None) - self._geometry = geometry - self._updateHandles() - - @docstring(HandleBasedROI) - def contains(self, position): - # first check distance, fastest - center = self.getCenter() - distance = numpy.sqrt((position[1] - center[1]) ** 2 + ((position[0] - center[0])) ** 2) - is_in_distance = self.getInnerRadius() <= distance <= self.getOuterRadius() - if not is_in_distance: - return False - rel_pos = position[1] - center[1], position[0] - center[0] - angle = numpy.arctan2(*rel_pos) - start_angle = self.getStartAngle() - end_angle = self.getEndAngle() - - if start_angle < end_angle: - # I never succeed to find a condition where start_angle < end_angle - # so this is untested - is_in_angle = start_angle <= angle <= end_angle - else: - if end_angle < -numpy.pi and angle > 0: - angle = angle - (numpy.pi *2.0) - is_in_angle = end_angle <= angle <= start_angle - return is_in_angle - - def translate(self, x, y): - self._geometry = self._geometry.translated(x, y) - self._updateHandles() - - def _arcCurvatureMarkerConstraint(self, x, y): - """Curvature marker remains on perpendicular bisector""" - geometry = self._geometry - if geometry.center is None: - center = (geometry.startPoint + geometry.endPoint) * 0.5 - vector = geometry.startPoint - geometry.endPoint - vector = numpy.array((vector[1], -vector[0])) - vdist = numpy.linalg.norm(vector) - if vdist != 0: - normal = numpy.array((vector[1], -vector[0])) / vdist - else: - normal = numpy.array((0, 0)) - else: - if geometry.isClosed(): - midAngle = geometry.startAngle + numpy.pi * 0.5 - else: - midAngle = (geometry.startAngle + geometry.endAngle) * 0.5 - normal = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)]) - center = geometry.center - dist = numpy.dot(normal, (numpy.array((x, y)) - center)) - dist = numpy.clip(dist, geometry.radius, geometry.radius * 2) - x, y = center + dist * normal - return x, y - - @staticmethod - def _circleEquation(pt1, pt2, pt3): - """Circle equation from 3 (x, y) points - - :return: Position of the center of the circle and the radius - :rtype: Tuple[Tuple[float,float],float] - """ - x, y, z = complex(*pt1), complex(*pt2), complex(*pt3) - w = z - x - w /= y - x - c = (x - y) * (w - abs(w) ** 2) / 2j / w.imag - x - return numpy.array((-c.real, -c.imag)), abs(c + x) - - def __str__(self): - try: - center, innerRadius, outerRadius, startAngle, endAngle = self.getGeometry() - params = center[0], center[1], innerRadius, outerRadius, startAngle, endAngle - params = 'center: %f %f; radius: %f %f; angles: %f %f' % params - except ValueError: - params = "invalid" - return "%s(%s)" % (self.__class__.__name__, params) - - class HorizontalRangeROI(RegionOfInterest, items.LineMixIn): """A ROI identifying an horizontal range in a 1D plot.""" @@ -2875,6 +1509,10 @@ class HorizontalRangeROI(RegionOfInterest, items.LineMixIn): marker = self.sender() self.setCenter(marker.getXPosition()) + @docstring(HandleBasedROI) + def contains(self, position): + return self.getMin() <= position[0] <= self.getMax() + def __str__(self): vrange = self.getRange() params = 'min: %f; max: %f' % vrange diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py index 5e7d65b..fd7cfae 100644 --- a/silx/gui/plot/items/scatter.py +++ b/silx/gui/plot/items/scatter.py @@ -332,6 +332,8 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): @docstring(ScatterVisualizationMixIn) def setVisualizationParameter(self, parameter, value): + parameter = self.VisualizationParameter.from_value(parameter) + if super(Scatter, self).setVisualizationParameter(parameter, value): if parameter in (self.VisualizationParameter.GRID_BOUNDS, self.VisualizationParameter.GRID_MAJOR_ORDER, @@ -339,8 +341,10 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): self.__cacheRegularGridInfo = None if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE, - self.VisualizationParameter.BINNED_STATISTIC_FUNCTION): - if parameter == self.VisualizationParameter.BINNED_STATISTIC_SHAPE: + self.VisualizationParameter.BINNED_STATISTIC_FUNCTION, + self.VisualizationParameter.DATA_BOUNDS_HINT): + if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE, + self.VisualizationParameter.DATA_BOUNDS_HINT): self.__cacheHistogramInfo = None # Clean-up cache if self.getVisualization() is self.Visualization.BINNED_STATISTIC: self._updateColormappedData() @@ -351,7 +355,8 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): @docstring(ScatterVisualizationMixIn) def getCurrentVisualizationParameter(self, parameter): value = self.getVisualizationParameter(parameter) - if value is not None: + if (parameter is self.VisualizationParameter.DATA_BOUNDS_HINT or + value is not None): return value # Value has been set, return it elif parameter is self.VisualizationParameter.GRID_BOUNDS: @@ -452,6 +457,12 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): ranges = (tuple(min_max(y, finite=True)), tuple(min_max(x, finite=True))) + rangesHint = self.getVisualizationParameter( + self.VisualizationParameter.DATA_BOUNDS_HINT) + if rangesHint is not None: + ranges = tuple((min(dataMin, hintMin), max(dataMax, hintMax)) + for (dataMin, dataMax), (hintMin, hintMax) in zip(ranges, rangesHint)) + points = numpy.transpose(numpy.array((y, x))) counts, sums, bin_edges = Histogramnd( points, @@ -850,7 +861,7 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn): if numpy.any(clipped): # copy to keep original array and convert to float - value = numpy.array(value, copy=True, dtype=numpy.float) + value = numpy.array(value, copy=True, dtype=numpy.float64) value[clipped] = numpy.nan x, y, xerror, yerror = PointsBase._logFilterData(self, xPositive, yPositive) diff --git a/silx/gui/plot/items/shape.py b/silx/gui/plot/items/shape.py index 26aa03b..955dfe3 100644 --- a/silx/gui/plot/items/shape.py +++ b/silx/gui/plot/items/shape.py @@ -36,7 +36,9 @@ import numpy import six from ... import colors -from .core import Item, ColorMixIn, FillMixIn, ItemChangedType, LineMixIn, YAxisMixIn +from .core import ( + Item, DataItem, + ColorMixIn, FillMixIn, ItemChangedType, LineMixIn, YAxisMixIn) _logger = logging.getLogger(__name__) @@ -154,7 +156,7 @@ class Shape(Item, ColorMixIn, FillMixIn, LineMixIn): self._updated(ItemChangedType.LINE_BG_COLOR) -class BoundingRect(Item, YAxisMixIn): +class BoundingRect(DataItem, YAxisMixIn): """An invisible shape which enforce the plot view to display the defined space on autoscale. @@ -166,21 +168,10 @@ class BoundingRect(Item, YAxisMixIn): """ def __init__(self): - Item.__init__(self) + DataItem.__init__(self) YAxisMixIn.__init__(self) self.__bounds = None - def _updated(self, event=None, checkVisibility=True): - if event in (ItemChangedType.YAXIS, - ItemChangedType.VISIBLE, - ItemChangedType.DATA): - # TODO hackish data range implementation - plot = self.getPlot() - if plot is not None: - plot._invalidateDataRange() - - super(BoundingRect, self)._updated(event, checkVisibility) - def setBounds(self, rect): """Set the bounding box of this item in data coordinates @@ -193,6 +184,7 @@ class BoundingRect(Item, YAxisMixIn): if rect != self.__bounds: self.__bounds = rect + self._boundsChanged() self._updated(ItemChangedType.DATA) def _getBounds(self): @@ -217,7 +209,7 @@ class BoundingRect(Item, YAxisMixIn): return self.__bounds -class _BaseExtent(Item): +class _BaseExtent(DataItem): """Base class for :class:`XAxisExtent` and :class:`YAxisExtent`. :param str axis: Either 'x' or 'y'. @@ -225,20 +217,10 @@ class _BaseExtent(Item): def __init__(self, axis='x'): assert axis in ('x', 'y') - Item.__init__(self) + DataItem.__init__(self) self.__axis = axis self.__range = 1., 100. - def _updated(self, event=None, checkVisibility=True): - if event in (ItemChangedType.VISIBLE, - ItemChangedType.DATA): - # TODO hackish data range implementation - plot = self.getPlot() - if plot is not None: - plot._invalidateDataRange() - - super(_BaseExtent, self)._updated(event, checkVisibility) - def setRange(self, min_, max_): """Set the range of the extent of this item in data coordinates. @@ -254,6 +236,7 @@ class _BaseExtent(Item): if range_ != self.__range: self.__range = range_ + self._boundsChanged() self._updated(ItemChangedType.DATA) def getRange(self): diff --git a/silx/gui/plot/matplotlib/__init__.py b/silx/gui/plot/matplotlib/__init__.py index f42bf53..e787240 100644 --- a/silx/gui/plot/matplotlib/__init__.py +++ b/silx/gui/plot/matplotlib/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2020 European Synchrotron Radiation Facility +# Copyright (c) 2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,49 +23,15 @@ # # ###########################################################################*/ -from __future__ import absolute_import - -"""This module initializes matplotlib and sets-up the backend to use. - -It MUST be imported prior to any other import of matplotlib. - -It provides the matplotlib :class:`FigureCanvasQTAgg` class corresponding -to the used backend. -""" - __authors__ = ["T. Vincent"] __license__ = "MIT" -__date__ = "02/05/2018" - - -from pkg_resources import parse_version -import matplotlib - -from ... import qt - - -def _matplotlib_use(backend, force): - """Wrapper of `matplotlib.use` to set-up backend. - - It adds extra initialization for PySide and PySide2 with matplotlib < 2.2. - """ - # This is kept for compatibility with matplotlib < 2.2 - if parse_version(matplotlib.__version__) < parse_version('2.2'): - if qt.BINDING == 'PySide': - matplotlib.rcParams['backend.qt4'] = 'PySide' - if qt.BINDING == 'PySide2': - matplotlib.rcParams['backend.qt5'] = 'PySide2' - - matplotlib.use(backend, force=force) - +__date__ = "15/07/2020" -if qt.BINDING in ('PyQt4', 'PySide'): - _matplotlib_use('Qt4Agg', force=False) - from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg # noqa +from silx.utils.deprecation import deprecated_warning -elif qt.BINDING in ('PyQt5', 'PySide2'): - _matplotlib_use('Qt5Agg', force=False) - from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg # noqa +deprecated_warning(type_='module', + name=__file__, + replacement='silx.gui.utils.matplotlib', + since_version='0.14.0') -else: - raise ImportError("Unsupported Qt binding: %s" % qt.BINDING) +from silx.gui.utils.matplotlib import FigureCanvasQTAgg # noqa diff --git a/silx/gui/plot/stats/stats.py b/silx/gui/plot/stats/stats.py index ad61536..755b185 100644 --- a/silx/gui/plot/stats/stats.py +++ b/silx/gui/plot/stats/stats.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2019 European Synchrotron Radiation Facility +# Copyright (c) 2017-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,7 +22,9 @@ # THE SOFTWARE. # # ###########################################################################*/ -"""This module provides the :class:`Scatter` item of the :class:`Plot`. +"""This module provides mechanism relative to stats calculation within a +:class:`PlotWidget`. +It also include the implementation of the statistics themselves. """ __authors__ = ["H. Payno"] @@ -31,13 +33,19 @@ __date__ = "06/06/2018" from collections import OrderedDict +from functools import lru_cache import logging import numpy +import numpy.ma from .. import items -from ....math.combo import min_max +from ..CurvesROIWidget import ROI +from ..items.roi import RegionOfInterest +from ....math.combo import min_max +from silx.utils.proxy import docstring +from ....utils.deprecation import deprecated logger = logging.getLogger(__name__) @@ -60,7 +68,8 @@ class Stats(OrderedDict): for stat in _statslist: self.add(stat) - def calculate(self, item, plot, onlimits): + def calculate(self, item, plot, onlimits, roi, data_changed=False, + roi_changed=False): """ Call all :class:`Stat` object registered and return the result of the computation. @@ -69,38 +78,31 @@ class Stats(OrderedDict): :param plot: plot containing the item :param bool onlimits: True if we want to apply statistic only on visible data. + :param roi: region of interest for statistic calculation. Incompatible + with the `onlimits` option. + :type roi: Union[None, :class:`~_RegionOfInterestBase`] + :param bool data_changed: did the data changed since last calculation. + :param bool roi_changed: did the associated roi (if any) has changed + since last calculation. :return dict: dictionary with :class:`Stat` name as ket and result of the calculation as value """ - context = None - # Check for PlotWidget items - if isinstance(item, items.Curve): - context = _CurveContext(item, plot, onlimits) - elif isinstance(item, items.ImageData): - context = _ImageContext(item, plot, onlimits) - elif isinstance(item, items.Scatter): - context = _ScatterContext(item, plot, onlimits) - elif isinstance(item, items.Histogram): - context = _HistogramContext(item, plot, onlimits) - else: - # Check for SceneWidget items - from ...plot3d import items as items3d # Lazy import - - if isinstance(item, (items3d.Scatter2D, items3d.Scatter3D)): - context = _plot3DScatterContext(item, plot, onlimits) - elif isinstance(item, (items3d.ImageData, items3d.ScalarField3D)): - context = _plot3DArrayContext(item, plot, onlimits) - - if context is None: - raise ValueError('Item type not managed') - res = {} + context = self._getContext(item=item, plot=plot, onlimits=onlimits, + roi=roi) for statName, stat in list(self.items()): if context.kind not in stat.compatibleKinds: logger.debug('kind %s not managed by statistic %s' % (context.kind, stat.name)) res[statName] = None else: + if roi_changed is True: + context.clear_mask() + if data_changed is True or roi_changed is True: + # if data changed or mask changed + context.clipData(item=item, plot=plot, onlimits=onlimits, + roi=roi) + # init roi and data res[statName] = stat.calculate(context) return res @@ -109,8 +111,40 @@ class Stats(OrderedDict): OrderedDict.__setitem__(self, key, value) def add(self, stat): + """Add a :class:`Stat` to the set + + :param Stat stat: stat to add to the set + """ self.__setitem__(key=stat.name, value=stat) + @staticmethod + @lru_cache(maxsize=50) + def _getContext(item, plot, onlimits, roi): + context = None + # Check for PlotWidget items + if isinstance(item, items.Curve): + context = _CurveContext(item, plot, onlimits, roi=roi) + elif isinstance(item, items.ImageData): + context = _ImageContext(item, plot, onlimits, roi=roi) + elif isinstance(item, items.Scatter): + context = _ScatterContext(item, plot, onlimits, roi=roi) + elif isinstance(item, items.Histogram): + context = _HistogramContext(item, plot, onlimits, roi=roi) + else: + # Check for SceneWidget items + from ...plot3d import items as items3d # Lazy import + + if isinstance(item, (items3d.Scatter2D, items3d.Scatter3D)): + context = _plot3DScatterContext(item, plot, onlimits, + roi=roi) + elif isinstance(item, + (items3d.ImageData, items3d.ScalarField3D)): + context = _plot3DArrayContext(item, plot, onlimits, + roi=roi) + if context is None: + raise ValueError('Item type not managed') + return context + class _StatsContext(object): """ @@ -127,8 +161,11 @@ class _StatsContext(object): :param plot: the plot containing the item :param bool onlimits: True if we want to apply statistic only on visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlimits` calculation + :type roi: Union[None,:class:`_RegionOfInterestBase`] """ - def __init__(self, item, kind, plot, onlimits): + def __init__(self, item, kind, plot, onlimits, roi): assert item assert plot assert type(onlimits) is bool @@ -136,9 +173,12 @@ class _StatsContext(object): self.min = None self.max = None self.data = None + self.roi = None + self.onlimits = onlimits self.values = None - """The array of data""" + """The array of data with limit filtering if any. Is a numpy.ma.array, + meaning that it embed the mask applied by the roi if any""" self.axes = None """A list of array of position on each axis. @@ -151,11 +191,69 @@ class _StatsContext(object): and the order is (x, y, z). """ - self.createContext(item, plot, onlimits) + self.clipData(item, plot, onlimits, roi=roi) + + def clipData(self, item, plot, onlimits, roi): + """ + Clip the data to the current mask to have accurate statistics + + :param item: item for whiwh we want to clip data + :param plot: plot containing the item + :param onlimits: do we want to apply statistic only on + visible data. + :param roi: Region of interest for computing the statistics. + :type roi: Union[None,:class:`_RegionOfInterestBase`] + """ + raise NotImplementedError() - def createContext(self, item, plot, onlimits): + def clear_mask(self): + """ + Remove the mask to force recomputation of it on next iteration + :return: + """ + raise NotImplementedError() + + @property + def mask(self): + if self.values is not None: + assert isinstance(self.values, numpy.ma.MaskedArray) + return self.values.mask + else: + return None + + @property + def is_mask_valid(self, **kwargs): + """Return if the mask is valid for the data or need to be recomputed""" + raise NotImplementedError("Base class") + + def _set_mask_validity(self, **kwargs): + """User to set some values that allows to define the mask properties + and boundaries""" raise NotImplementedError("Base class") + def clipData(self, item, plot, onlimits, roi): + """ + Function called before computing each statistics associated to this + context. It will insure the context for the (item, plot, onlimits, roi) + is created. + + :param item: item for which we want statistics + :param plot: plot containing the statistics + :param bool onlimits: True if we want to apply statistic only on + visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlimits` calculation + :type roi: Union[None,:class:`_RegionOfInterestBase`] + """ + raise NotImplementedError("Base class") + + @deprecated(reason="context are now stored and keep during stats life." + "So this function will be called only once", + replacement="clipData", since_version="0.13.0") + def createContext(self, item, plot, onlimits, roi): + return self.clipData(item=item, plot=plot, onlimits=onlimits, + roi=roi) + def isStructuredData(self): """Returns True if data as an array-like structure. @@ -184,8 +282,34 @@ class _StatsContext(object): else: return self.values.ndim == 1 + def _checkContextInputs(self, item, plot, onlimits, roi): + if roi is not None and onlimits is True: + raise ValueError('Stats context is unable to manage both a ROI' + 'and the `onlimits` option') + + +class _ScatterCurveHistoMixInContext(_StatsContext): + def __init__(self, kind, item, plot, onlimits, roi): + self.clear_mask() + _StatsContext.__init__(self, item=item, kind=kind, + plot=plot, onlimits=onlimits, roi=roi) -class _CurveContext(_StatsContext): + def _set_mask_validity(self, onlimits, from_, to_): + self._onlimits = onlimits + self._from_ = from_ + self._to_ = to_ + + def clear_mask(self): + self._onlimits = None + self._from_ = None + self._to_ = None + + def is_mask_valid(self, onlimits, from_, to_): + return (onlimits == self.onlimits and from_ == self._from_ and + to_ == self._to_) + + +class _CurveContext(_ScatterCurveHistoMixInContext): """ StatsContext for :class:`Curve` @@ -193,32 +317,63 @@ class _CurveContext(_StatsContext): :param plot: the plot containing the item :param bool onlimits: True if we want to apply statistic only on visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlinits` calculation + :type roi: Union[None, :class:`ROI`] """ - def __init__(self, item, plot, onlimits): - _StatsContext.__init__(self, kind='curve', item=item, - plot=plot, onlimits=onlimits) - - def createContext(self, item, plot, onlimits): + def __init__(self, item, plot, onlimits, roi): + _ScatterCurveHistoMixInContext.__init__(self, kind='curve', item=item, + plot=plot, onlimits=onlimits, + roi=roi) + + @docstring(_StatsContext) + def clipData(self, item, plot, onlimits, roi): + self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, + roi=roi) + self.roi = roi + self.onlimits = onlimits xData, yData = item.getData(copy=True)[0:2] if onlimits: minX, maxX = plot.getXAxis().getLimits() - mask = (minX <= xData) & (xData <= maxX) + if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX): + mask = self.mask + else: + mask = (minX <= xData) & (xData <= maxX) yData = yData[mask] xData = xData[mask] + mask = numpy.zeros_like(yData) + elif roi: + minX, maxX = roi.getFrom(), roi.getTo() + if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX): + mask = self.mask + else: + mask = (minX <= xData) & (xData <= maxX) + mask = mask == 0 + mask = mask.astype(numpy.int32) + else: + mask = numpy.zeros_like(yData) self.xData = xData self.yData = yData - if len(yData) > 0: - self.min, self.max = min_max(yData) + self.values = numpy.ma.array(yData, mask=mask) + unmasked_data = self.values.compressed() + if len(unmasked_data) > 0: + self.min, self.max = min_max(unmasked_data) else: self.min, self.max = None, None self.data = (xData, yData) - self.values = yData + self.axes = (xData,) + def _checkContextInputs(self, item, plot, onlimits, roi): + _StatsContext._checkContextInputs(self, item=item, plot=plot, + onlimits=onlimits, roi=roi) + if roi is not None and not isinstance(roi, ROI): + raise TypeError('curve `context` can ony manage 1D roi') -class _HistogramContext(_StatsContext): + +class _HistogramContext(_ScatterCurveHistoMixInContext): """ StatsContext for :class:`Histogram` @@ -226,32 +381,66 @@ class _HistogramContext(_StatsContext): :param plot: the plot containing the item :param bool onlimits: True if we want to apply statistic only on visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlinits` calculation + :type roi: Union[None, :class:`ROI`] """ - def __init__(self, item, plot, onlimits): - _StatsContext.__init__(self, kind='histogram', item=item, - plot=plot, onlimits=onlimits) - - def createContext(self, item, plot, onlimits): + def __init__(self, item, plot, onlimits, roi): + _ScatterCurveHistoMixInContext.__init__(self, kind='histogram', + item=item, plot=plot, + onlimits=onlimits, roi=roi) + + @docstring(_StatsContext) + def clipData(self, item, plot, onlimits, roi): + self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, + roi=roi) yData, edges = item.getData(copy=True)[0:2] xData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment()) + if onlimits: minX, maxX = plot.getXAxis().getLimits() - mask = (minX <= xData) & (xData <= maxX) + if self.is_mask_valid(onlimits, from_=minX, to_=maxX): + mask = self.mask + else: + mask = (minX <= xData) & (xData <= maxX) + self._set_mask_validity(onlimits=True, from_=minX, to_=maxX) + elif roi: + if self.is_mask_valid(onlimits, from_=roi._fromdata, to_=roi._todata): + mask = self.mask + else: + mask = (roi._fromdata <= xData) & (xData <= roi._todata) + mask = mask == 0 + self._set_mask_validity(onlimits=True, from_=roi._fromdata, + to_=roi._todata) + else: + mask = numpy.zeros_like(self.data) + + if onlimits: yData = yData[mask] xData = xData[mask] + self.data = (xData, yData) + self.values = numpy.ma.array(yData, mask=mask) + self.axes = (xData,) + self.xData = xData self.yData = yData - if len(yData) > 0: - self.min, self.max = min_max(yData) + + unmasked_data = self.values.compressed() + if len(unmasked_data) > 0: + self.min, self.max = min_max(unmasked_data) else: self.min, self.max = None, None - self.data = (xData, yData) - self.values = yData - self.axes = (xData,) + def _checkContextInputs(self, item, plot, onlimits, roi): + _StatsContext._checkContextInputs(self, item=item, plot=plot, + onlimits=onlimits, roi=roi) + + if roi is not None and not isinstance(roi, ROI): + raise TypeError('curve `context` can ony manage 1D roi') -class _ScatterContext(_StatsContext): + +class _ScatterContext(_ScatterCurveHistoMixInContext): """StatsContext scatter plots. It supports :class:`~silx.gui.plot.items.Scatter`. @@ -260,12 +449,19 @@ class _ScatterContext(_StatsContext): :param plot: the plot containing the item :param bool onlimits: True if we want to apply statistic only on visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlinits` calculation + :type roi: Union[None, :class:`ROI`] """ - def __init__(self, item, plot, onlimits): - _StatsContext.__init__(self, kind='scatter', item=item, plot=plot, - onlimits=onlimits) - - def createContext(self, item, plot, onlimits): + def __init__(self, item, plot, onlimits, roi): + _ScatterCurveHistoMixInContext.__init__(self, kind='scatter', + item=item, plot=plot, + onlimits=onlimits, roi=roi) + + @docstring(_ScatterCurveHistoMixInContext) + def clipData(self, item, plot, onlimits, roi): + self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, + roi=roi) valueData = item.getValueData(copy=True) xData = item.getXData(copy=True) yData = item.getYData(copy=True) @@ -283,34 +479,89 @@ class _ScatterContext(_StatsContext): xData = xData[(minY <= yData) & (yData <= maxY)] yData = yData[(minY <= yData) & (yData <= maxY)] - if len(valueData) > 0: - self.min, self.max = min_max(valueData) + if roi: + if self.is_mask_valid(onlimits=onlimits, from_=roi.getFrom(), + to_=roi.getTo()): + mask = self.mask + else: + mask = (xData < roi.getFrom()) | (xData > roi.getTo()) else: - self.min, self.max = None, None + mask = numpy.zeros_like(xData) + self.data = (xData, yData, valueData) - self.values = valueData + self.values = numpy.ma.array(valueData, mask=mask) self.axes = (xData, yData) + unmasked_values = self.values.compressed() + if len(unmasked_values) > 0: + self.min, self.max = min_max(unmasked_values) + else: + self.min, self.max = None, None + + def _checkContextInputs(self, item, plot, onlimits, roi): + _StatsContext._checkContextInputs(self, item=item, plot=plot, + onlimits=onlimits, roi=roi) + + if roi is not None and not isinstance(roi, ROI): + raise TypeError('curve `context` can ony manage 1D roi') + class _ImageContext(_StatsContext): """StatsContext for images. It supports :class:`~silx.gui.plot.items.ImageData`. + :warning: behaviour of scale images: now the statistics are computed on + the entire data array (there is no sampling in the array or + interpolation regarding the scale). + This also mean that the result can differ from what is displayed. + But I guess there is no perfect behaviour. + + :warning: `isIn` functions for image context: for now have basically a + binary approach, the pixel is in a roi or not. To have a fully + 'correct behaviour' we should add a weight on stats calculation + to moderate the pixel value. + :param item: the item for which we want to compute the context :param plot: the plot containing the item :param bool onlimits: True if we want to apply statistic only on visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlinits` calculation + :type roi: Union[None, :class:`ROI`] """ - def __init__(self, item, plot, onlimits): + def __init__(self, item, plot, onlimits, roi): + self.clear_mask() _StatsContext.__init__(self, kind='image', item=item, - plot=plot, onlimits=onlimits) - - def createContext(self, item, plot, onlimits): + plot=plot, onlimits=onlimits, roi=roi) + + def _set_mask_validity(self, xmin: float, xmax: float, ymin: float, ymax + : float): + self._mask_x_min = xmin + self._mask_x_max = xmax + self._mask_y_min = ymin + self._mask_y_max = ymax + + def clear_mask(self): + self._mask_x_min = None + self._mask_x_max = None + self._mask_y_min = None + self._mask_y_max = None + + def is_mask_valid(self, xmin, xmax, ymin, ymax): + return (xmin == self._mask_x_min and xmax == self._mask_x_max and + ymin == self._mask_y_min and ymax == self._mask_y_max) + + @docstring(_StatsContext) + def clipData(self, item, plot, onlimits, roi): + self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, + roi=roi) self.origin = item.getOrigin() self.scale = item.getScale() self.data = item.getData(copy=True) + mask = numpy.zeros_like(self.data) + """mask use to know of the stat should be count in or not""" if onlimits: minX, maxX = plot.getXAxis().getLimits() @@ -324,21 +575,50 @@ class _ImageContext(_StatsContext): XMinBound = max(XMinBound, 0) YMinBound = max(YMinBound, 0) + if onlimits: if XMaxBound <= XMinBound or YMaxBound <= YMinBound: self.data = None else: self.data = self.data[YMinBound:YMaxBound + 1, XMinBound:XMaxBound + 1] - if self.data.size > 0: - self.min, self.max = min_max(self.data) + mask = numpy.zeros_like(self.data) + elif roi: + minX, maxX = 0, self.data.shape[1] + minY, maxY = 0, self.data.shape[0] + + XMinBound = max(minX, 0) + YMinBound = max(minY, 0) + XMaxBound = min(maxX, self.data.shape[1]) + YMaxBound = min(maxY, self.data.shape[0]) + + if self.is_mask_valid(xmin=XMinBound, xmax=XMaxBound, + ymin=YMinBound, ymax=YMaxBound): + mask = self.mask + else: + for x in range(XMinBound, XMaxBound): + for y in range(YMinBound, YMaxBound): + _x = (x * self.scale[0]) + self.origin[0] + _y = (y * self.scale[1]) + self.origin[1] + mask[y, x] = not roi.contains((_x, _y)) + self._set_mask_validity(xmin=XMinBound, xmax=XMaxBound, + ymin=YMinBound, ymax=YMaxBound) + self.values = numpy.ma.array(self.data, mask=mask) + if self.values.compressed().size > 0: + self.min, self.max = min_max(self.values.compressed()) else: self.min, self.max = None, None - self.values = self.data if self.values is not None: self.axes = (self.origin[1] + self.scale[1] * numpy.arange(self.data.shape[0]), self.origin[0] + self.scale[0] * numpy.arange(self.data.shape[1])) + def _checkContextInputs(self, item, plot, onlimits, roi): + _StatsContext._checkContextInputs(self, item=item, plot=plot, + onlimits=onlimits, roi=roi) + + if roi is not None and not isinstance(roi, RegionOfInterest): + raise TypeError('curve `context` can ony manage 2D roi') + class _plot3DScatterContext(_StatsContext): """StatsContext for 3D scatter plots. @@ -350,16 +630,26 @@ class _plot3DScatterContext(_StatsContext): :param plot: the plot containing the item :param bool onlimits: True if we want to apply statistic only on visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlinits` calculation + :type roi: Union[None, :class:`ROI`] """ - def __init__(self, item, plot, onlimits): + def __init__(self, item, plot, onlimits, roi): _StatsContext.__init__(self, kind='scatter', item=item, plot=plot, - onlimits=onlimits) + onlimits=onlimits, roi=roi) - def createContext(self, item, plot, onlimits): + @docstring(_StatsContext) + def clipData(self, item, plot, onlimits, roi): + self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, + roi=roi) if onlimits: raise RuntimeError("Unsupported plot %s" % str(plot)) - values = item.getValueData(copy=False) + if roi: + logger.warning("Roi are unsupported on volume for now") + mask = numpy.zeros_like(values) + else: + mask = numpy.zeros_like(values) if values is not None and len(values) > 0: self.values = values @@ -367,13 +657,20 @@ class _plot3DScatterContext(_StatsContext): if self.values.ndim == 3: axes.append(item.getZData(copy=False)) self.axes = tuple(axes) - self.min, self.max = min_max(self.values) + self.values = numpy.ma.array(self.values, mask=mask) else: self.values = None self.axes = None self.min, self.max = None, None + def _checkContextInputs(self, item, plot, onlimits, roi): + _StatsContext._checkContextInputs(self, item=item, plot=plot, + onlimits=onlimits, roi=roi) + + if roi is not None and not isinstance(roi, RegionOfInterest): + raise TypeError('curve `context` can ony manage 2D roi') + class _plot3DArrayContext(_StatsContext): """StatsContext for 3D scalar field and data image. @@ -385,26 +682,45 @@ class _plot3DArrayContext(_StatsContext): :param plot: the plot containing the item :param bool onlimits: True if we want to apply statistic only on visible data. + :param roi: Region of interest for computing the statistics. + For now, incompatible with `onlinits` calculation + :type roi: Union[None, :class:`ROI`] """ - def __init__(self, item, plot, onlimits): + def __init__(self, item, plot, onlimits, roi): _StatsContext.__init__(self, kind='image', item=item, plot=plot, - onlimits=onlimits) + onlimits=onlimits, roi=roi) - def createContext(self, item, plot, onlimits): + @docstring(_StatsContext) + def clipData(self, item, plot, onlimits, roi): + self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, + roi=roi) if onlimits: raise RuntimeError("Unsupported plot %s" % str(plot)) values = item.getData(copy=False) + if roi: + logger.warning("Roi are unsuported on volume for now") + mask = numpy.zeros_like(values) + else: + mask = numpy.zeros_like(values) if values is not None and len(values) > 0: self.values = values self.axes = tuple([numpy.arange(size) for size in self.values.shape]) self.min, self.max = min_max(self.values) + self.values = numpy.ma.array(self.values, mask=mask) else: self.values = None self.axes = None self.min, self.max = None, None + def _checkContextInputs(self, item, plot, onlimits, roi): + _StatsContext._checkContextInputs(self, item=item, plot=plot, + onlimits=onlimits, roi=roi) + + if roi is not None and not isinstance(roi, RegionOfInterest): + raise TypeError('curve `context` can ony manage 2D roi') + BASIC_COMPATIBLE_KINDS = 'curve', 'image', 'scatter', 'histogram' @@ -456,6 +772,7 @@ class Stat(StatBase): StatBase.__init__(self, name, kinds) self._fct = fct + @docstring(StatBase) def calculate(self, context): if context.values is not None: if context.kind in self.compatibleKinds: @@ -472,6 +789,7 @@ class StatMin(StatBase): def __init__(self): StatBase.__init__(self, name='min') + @docstring(StatBase) def calculate(self, context): return context.min @@ -481,6 +799,7 @@ class StatMax(StatBase): def __init__(self): StatBase.__init__(self, name='max') + @docstring(StatBase) def calculate(self, context): return context.max @@ -490,6 +809,7 @@ class StatDelta(StatBase): def __init__(self): StatBase.__init__(self, name='delta') + @docstring(StatBase) def calculate(self, context): return context.max - context.min @@ -506,14 +826,17 @@ class _StatCoord(StatBase): :param int index: Index in the flattened data array :rtype: List[int] """ - if context.isStructuredData(): + + axes = context.axes + + if context.isStructuredData() or context.roi: coordinates = [] - for axis in reversed(context.axes): + for axis in reversed(axes): coordinates.append(axis[index % len(axis)]) index = index // len(axis) return tuple(coordinates) else: - return tuple(axis[index] for axis in context.axes) + return tuple(axis[index] for axis in axes) class StatCoordMin(_StatCoord): @@ -521,13 +844,15 @@ class StatCoordMin(_StatCoord): def __init__(self): _StatCoord.__init__(self, name='coords min') + @docstring(StatBase) def calculate(self, context): if context.values is None or not context.isScalarData(): return None - index = numpy.argmin(context.values) + index = context.values.argmin() return self._indexToCoordinates(context, index) + @docstring(StatBase) def getToolTip(self, kind): return "Coordinates of the first minimum value of the data" @@ -537,13 +862,17 @@ class StatCoordMax(_StatCoord): def __init__(self): _StatCoord.__init__(self, name='coords max') + @docstring(StatBase) def calculate(self, context): if context.values is None or not context.isScalarData(): return None - index = numpy.argmax(context.values) + # TODO: the values should be a mask array by default, will be simpler + # if possible + index = context.values.argmax() return self._indexToCoordinates(context, index) + @docstring(StatBase) def getToolTip(self, kind): return "Coordinates of the first maximum value of the data" @@ -553,11 +882,12 @@ class StatCOM(StatBase): def __init__(self): StatBase.__init__(self, name='COM', description='Center of mass') + @docstring(StatBase) def calculate(self, context): if context.values is None or not context.isScalarData(): return None - values = numpy.array(context.values, dtype=numpy.float64) + values = numpy.ma.array(context.values, mask=context.mask, dtype=numpy.float64) sum_ = numpy.sum(values) if sum_ == 0.: return (numpy.nan,) * len(context.axes) @@ -573,5 +903,6 @@ class StatCOM(StatBase): return tuple( numpy.sum(axis * values) / sum_ for axis in context.axes) + @docstring(StatBase) def getToolTip(self, kind): return "Compute the center of mass of the dataset" diff --git a/silx/gui/plot/stats/statshandler.py b/silx/gui/plot/stats/statshandler.py index f69daff..17578d8 100644 --- a/silx/gui/plot/stats/statshandler.py +++ b/silx/gui/plot/stats/statshandler.py @@ -22,7 +22,8 @@ # THE SOFTWARE. # # ###########################################################################*/ -""" +"""This module containts the classes relative to the management of statistics +display. """ __authors__ = ["H. Payno"] @@ -178,7 +179,8 @@ class StatsHandler(object): else: return self.formatters[name].format(val) - def calculate(self, item, plot, onlimits): + def calculate(self, item, plot, onlimits, roi=None, data_changed=False, + roi_changed=False): """ compute all statistic registered and return the list of formatted statistics result. @@ -187,10 +189,14 @@ class StatsHandler(object): :param plot: plot containing the item :param onlimits: True if we want to compute statistics on visible data only + :type: bool + :param roi: region of interest for statistic calculation + :type: Union[None,:class:`_RegionOfInterestBase`] :return: list of formatted statistics (as str) :rtype: dict """ - res = self.stats.calculate(item, plot, onlimits) + res = self.stats.calculate(item, plot, onlimits, roi, + data_changed=data_changed, roi_changed=roi_changed) for resName, resValue in list(res.items()): res[resName] = self.format(resName, res[resName]) return res diff --git a/silx/gui/plot/test/__init__.py b/silx/gui/plot/test/__init__.py index 0477e2a..dfb7c2e 100644 --- a/silx/gui/plot/test/__init__.py +++ b/silx/gui/plot/test/__init__.py @@ -53,6 +53,7 @@ from . import testSaveAction from . import testScatterView from . import testPixelIntensityHistoAction from . import testCompareImages +from . import testRoiStatsWidget def suite(): @@ -86,5 +87,6 @@ def suite(): testScatterView.suite(), testPixelIntensityHistoAction.suite(), testCompareImages.suite(), + testRoiStatsWidget.suite(), ]) return test_suite diff --git a/silx/gui/plot/test/testComplexImageView.py b/silx/gui/plot/test/testComplexImageView.py index 051ec4d..4ac3488 100644 --- a/silx/gui/plot/test/testComplexImageView.py +++ b/silx/gui/plot/test/testComplexImageView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2019 European Synchrotron Radiation Facility +# Copyright (c) 2017-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -50,7 +50,7 @@ class TestComplexImageView(PlotWidgetTestCase, ParametricTestCase): def testPlot2DComplex(self): """Test API of ComplexImageView widget""" - data = numpy.array(((0, 1j), (1, 1 + 1j)), dtype=numpy.complex) + data = numpy.array(((0, 1j), (1, 1 + 1j)), dtype=numpy.complex64) self.plot.setData(data) self.plot.setKeepDataAspectRatio(True) self.plot.getPlot().resetZoom() @@ -76,11 +76,11 @@ class TestComplexImageView(PlotWidgetTestCase, ParametricTestCase): self.qWait(100) # Test no data - self.plot.setData(numpy.zeros((0, 0), dtype=numpy.complex)) + self.plot.setData(numpy.zeros((0, 0), dtype=numpy.complex64)) self.qWait(100) # Test float data - self.plot.setData(numpy.arange(100, dtype=numpy.float).reshape(10, 10)) + self.plot.setData(numpy.arange(100, dtype=numpy.float64).reshape(10, 10)) self.qWait(100) diff --git a/silx/gui/plot/test/testCurvesROIWidget.py b/silx/gui/plot/test/testCurvesROIWidget.py index 77c53a8..6a0ab8c 100644 --- a/silx/gui/plot/test/testCurvesROIWidget.py +++ b/silx/gui/plot/test/testCurvesROIWidget.py @@ -375,13 +375,13 @@ class TestRoiWidgetSignals(TestCaseQt): self.listener.clear() roi1 = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5) - self.curves_roi_widget.roiTable.addRoi(roi1) + self.curves_roi_widget.roiTable.registerROI(roi1) self.assertEqual(self.listener.callCount(), 1) self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear') self.listener.clear() roi2 = CurvesROIWidget.ROI(name='linear2', fromdata=0, todata=5) - self.curves_roi_widget.roiTable.addRoi(roi2) + self.curves_roi_widget.roiTable.registerROI(roi2) self.assertEqual(self.listener.callCount(), 1) self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear2') self.listener.clear() @@ -398,7 +398,7 @@ class TestRoiWidgetSignals(TestCaseQt): self.assertTrue(self.listener.arguments()[0][0]['current'] is None) self.listener.clear() - self.curves_roi_widget.roiTable.addRoi(roi1) + self.curves_roi_widget.roiTable.registerROI(roi1) self.assertEqual(self.listener.callCount(), 1) self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear') self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1) @@ -415,7 +415,7 @@ class TestRoiWidgetSignals(TestCaseQt): """Test SigROISignal when modifying it""" self.curves_roi_widget.roiTable.setMiddleROIMarkerFlag(True) roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5) - self.curves_roi_widget.roiTable.addRoi(roi1) + self.curves_roi_widget.roiTable.registerROI(roi1) self.curves_roi_widget.roiTable.setActiveRoi(roi1) # test modify the roi2 object @@ -450,7 +450,7 @@ class TestRoiWidgetSignals(TestCaseQt): def testSetActiveCurve(self): """Test sigRoiSignal when set an active curve""" roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5) - self.curves_roi_widget.roiTable.addRoi(roi1) + self.curves_roi_widget.roiTable.registerROI(roi1) self.curves_roi_widget.roiTable.setActiveRoi(roi1) self.listener.clear() self.plot.setActiveCurve('curve0') diff --git a/silx/gui/plot/test/testItem.py b/silx/gui/plot/test/testItem.py index ad739a2..8dacdea 100644 --- a/silx/gui/plot/test/testItem.py +++ b/silx/gui/plot/test/testItem.py @@ -35,6 +35,7 @@ import numpy from silx.gui.utils.testutils import SignalListener from silx.gui.plot.items import ItemChangedType +from silx.gui.plot import items from .utils import PlotWidgetTestCase @@ -242,11 +243,96 @@ class TestSymbol(PlotWidgetTestCase): self.assertEqual('Diamond', name) +class TestVisibleExtent(PlotWidgetTestCase): + """Test item's visible extent feature""" + + def testGetVisibleBounds(self): + """Test Item.getVisibleBounds""" + + # Create test items (with a bounding box of x: [1,3], y: [0,2]) + curve = items.Curve() + curve.setData((1, 2, 3), (0, 1, 2)) + + histogram = items.Histogram() + histogram.setData((0, 1, 2), (1, 5/3, 7/3, 3)) + + image = items.ImageData() + image.setOrigin((1, 0)) + image.setData(numpy.arange(4).reshape(2, 2)) + + scatter = items.Scatter() + scatter.setData((1, 2, 3), (0, 1, 2), (1, 2, 3)) + + bbox = items.BoundingRect() + bbox.setBounds((1, 3, 0, 2)) + + xaxis, yaxis = self.plot.getXAxis(), self.plot.getYAxis() + for item in (curve, histogram, image, scatter, bbox): + with self.subTest(item=item): + xaxis.setLimits(0, 100) + yaxis.setLimits(0, 100) + self.plot.addItem(item) + self.assertEqual(item.getVisibleBounds(), (1., 3., 0., 2.)) + + xaxis.setLimits(0.5, 2.5) + self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0., 2.)) + + yaxis.setLimits(0.5, 1.5) + self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0.5, 1.5)) + + item.setVisible(False) + self.assertIsNone(item.getVisibleBounds()) + + self.plot.clear() + + def testVisibleExtentTracking(self): + """Test Item's visible extent tracking""" + image = items.ImageData() + image.setData(numpy.arange(6).reshape(2, 3)) + + listener = SignalListener() + image._sigVisibleBoundsChanged.connect(listener) + image._setVisibleBoundsTracking(True) + self.assertTrue(image._isVisibleBoundsTracking()) + + self.plot.addItem(image) + self.assertEqual(listener.callCount(), 1) + + self.plot.getXAxis().setLimits(0, 1) + self.assertEqual(listener.callCount(), 2) + + self.plot.hide() + self.qapp.processEvents() + # No event here + self.assertEqual(listener.callCount(), 2) + + self.plot.getXAxis().setLimits(1, 2) + # No event since PlotWidget is hidden, delayed to PlotWidget show + self.assertEqual(listener.callCount(), 2) + + self.plot.show() + self.qapp.processEvents() + # Receives delayed event now + self.assertEqual(listener.callCount(), 3) + + image.setOrigin((-1, -1)) + self.assertEqual(listener.callCount(), 4) + + image.setVisible(False) + image.setOrigin((0, 0)) + # No event since item is not visible + self.assertEqual(listener.callCount(), 4) + + image.setVisible(True) + # Receives delayed event now + self.assertEqual(listener.callCount(), 5) + + def suite(): test_suite = unittest.TestSuite() loadTests = unittest.defaultTestLoader.loadTestsFromTestCase - test_suite.addTest(loadTests(TestSigItemChangedSignal)) - test_suite.addTest(loadTests(TestSymbol)) + for klass in (TestSigItemChangedSignal, TestSymbol, TestVisibleExtent): + test_suite.addTest(loadTests(klass)) return test_suite diff --git a/silx/gui/plot/test/testMaskToolsWidget.py b/silx/gui/plot/test/testMaskToolsWidget.py index a05c1be..2e8db55 100644 --- a/silx/gui/plot/test/testMaskToolsWidget.py +++ b/silx/gui/plot/test/testMaskToolsWidget.py @@ -84,10 +84,15 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase): self.mouseMove(plot, pos=(0, 0)) self.mouseMove(plot, pos=pos0) - self.mouseClick(plot, qt.Qt.LeftButton, pos=pos0) - self.mouseMove(plot, pos=(0, 0)) + self.qapp.processEvents() + self.mousePress(plot, qt.Qt.LeftButton, pos=pos0) + self.qapp.processEvents() + self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2)) self.mouseMove(plot, pos=pos1) - self.mouseClick(plot, qt.Qt.LeftButton, pos=pos1) + self.qapp.processEvents() + self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1) + self.qapp.processEvents() + self.mouseMove(plot, pos=(0, 0)) def _drawPolygon(self): """Draw a star polygon in the plot""" @@ -106,7 +111,9 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase): for pos in star: self.mouseMove(plot, pos=pos) self.qapp.processEvents() - self.mouseClick(plot, qt.Qt.LeftButton, pos=pos) + self.mousePress(plot, qt.Qt.LeftButton, pos=pos) + self.qapp.processEvents() + self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos) self.qapp.processEvents() def _drawPencil(self): diff --git a/silx/gui/plot/test/testPlotInteraction.py b/silx/gui/plot/test/testPlotInteraction.py index 335b1e4..7a30434 100644 --- a/silx/gui/plot/test/testPlotInteraction.py +++ b/silx/gui/plot/test/testPlotInteraction.py @@ -68,7 +68,11 @@ class TestSelectPolygon(PlotWidgetTestCase): for pos in polygon: self.mouseMove(plot, pos=pos) - self.mouseClick(plot, qt.Qt.LeftButton, pos=pos) + self.qapp.processEvents() + self.mousePress(plot, qt.Qt.LeftButton, pos=pos) + self.qapp.processEvents() + self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos) + self.qapp.processEvents() self.plot.sigPlotSignal.disconnect(dump) return [args[0] for args in dump.received] diff --git a/silx/gui/plot/test/testPlotWidget.py b/silx/gui/plot/test/testPlotWidget.py index 4ef6a72..f9d2281 100755 --- a/silx/gui/plot/test/testPlotWidget.py +++ b/silx/gui/plot/test/testPlotWidget.py @@ -43,7 +43,7 @@ from silx.test.utils import test_options from silx.gui import qt from silx.gui.plot import PlotWidget from silx.gui.plot.items.curve import CurveStyle -from silx.gui.plot.items import BoundingRect, XAxisExtent, YAxisExtent +from silx.gui.plot.items import BoundingRect, XAxisExtent, YAxisExtent, Axis from silx.gui.colors import Colormap from .utils import PlotWidgetTestCase @@ -326,6 +326,23 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase): resetzoom=False) self.plot.resetZoom() + def testPlotColormapNaNColor(self): + self.plot.setKeepDataAspectRatio(False) + self.plot.setGraphTitle('Colormap with NaN color') + + colormap = Colormap() + colormap.setNaNColor('red') + self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0)) + data = DATA_2D.astype(numpy.float32) + data[len(data)//2:] = numpy.nan + self.plot.addImage(data, legend="image 1", colormap=colormap, + resetzoom=False) + self.plot.resetZoom() + + colormap.setNaNColor((0., 1., 0., 1.)) + self.assertEqual(colormap.getNaNColor(), qt.QColor(0, 255, 0)) + self.qapp.processEvents() + def testImageOriginScale(self): """Test of image with different origin and scale""" self.plot.setGraphTitle('origin and scale') @@ -401,7 +418,7 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase): def testPlotBooleanImage(self): """Test that a boolean image is displayed and converted to int8.""" - data = numpy.zeros((10, 10), dtype=numpy.bool) + data = numpy.zeros((10, 10), dtype=bool) data[::2, ::2] = True self.plot.addImage(data, legend='boolean') @@ -438,6 +455,21 @@ class TestPlotCurve(PlotWidgetTestCase): self.plot.setActiveCurveHandling(False) + def testPlotCurveInfinite(self): + """Test plot curves with not finite data""" + tests = { + 'y all not finite': ([0, 1, 2], [numpy.inf, numpy.nan, -numpy.inf]), + 'x all not finite': ([numpy.inf, numpy.nan, -numpy.inf], [0, 1, 2]), + 'x some inf': ([0, numpy.inf, 2], [0, 1, 2]), + 'y some inf': ([0, 1, 2], [0, numpy.inf, 2]) + } + for name, args in tests.items(): + with self.subTest(name): + self.plot.addCurve(*args) + self.plot.resetZoom() + self.qapp.processEvents() + self.plot.clear() + def testPlotCurveColorFloat(self): color = numpy.array(numpy.random.random(3 * 1000), dtype=numpy.float32).reshape(1000, 3) @@ -799,17 +831,25 @@ class TestPlotItem(PlotWidgetTestCase): """Basic tests for addItem.""" # Polygon coordinates and color - polygons = [ # legend, x coords, y coords, color + POLYGONS = [ # legend, x coords, y coords, color ('triangle', numpy.array((10, 30, 50)), numpy.array((55, 70, 55)), 'red'), ('square', numpy.array((10, 10, 50, 50)), numpy.array((10, 50, 50, 10)), 'green'), ('star', numpy.array((60, 70, 80, 60, 80)), numpy.array((25, 50, 25, 40, 40)), 'blue'), + ('2 triangles-simple', + numpy.array((90., 95., 100., numpy.nan, 90., 95., 100.)), + numpy.array((25., 5., 25., numpy.nan, 30., 50., 30.)), + 'pink'), + ('2 triangles-extra NaN', + numpy.array((numpy.nan, 90., 95., 100., numpy.nan, 0., 90., 95., 100., numpy.nan)), + numpy.array((0., 55., 70., 55., numpy.nan, numpy.nan, 75., 90., 75., numpy.nan)), + 'black'), ] # Rectangle coordinantes and color - rectangles = [ # legend, x coords, y coords, color + RECTANGLES = [ # legend, x coords, y coords, color ('square 1', numpy.array((1., 10.)), numpy.array((1., 10.)), 'red'), ('square 2', numpy.array((10., 20.)), @@ -822,6 +862,8 @@ class TestPlotItem(PlotWidgetTestCase): numpy.array((45., 45.)), 'darkRed'), ] + SCALES = Axis.LINEAR, Axis.LOGARITHMIC + def setUp(self): super(TestPlotItem, self).setUp() @@ -833,40 +875,60 @@ class TestPlotItem(PlotWidgetTestCase): self.plot.setLimits(0., 100., -100., 100.) def testPlotItemPolygonFill(self): - self.plot.setGraphTitle('Item Fill') - - for legend, xList, yList, color in self.polygons: - self.plot.addShape(xList, yList, legend=legend, - replace=False, - shape="polygon", fill=True, color=color) - self.plot.resetZoom() + for scale in self.SCALES: + with self.subTest(scale=scale): + self.plot.clear() + self.plot.getXAxis().setScale(scale) + self.plot.getYAxis().setScale(scale) + self.plot.setGraphTitle('Item Fill %s' % scale) + + for legend, xList, yList, color in self.POLYGONS: + self.plot.addShape(xList, yList, legend=legend, + replace=False, linestyle='--', + shape="polygon", fill=True, color=color) + self.plot.resetZoom() def testPlotItemPolygonNoFill(self): - self.plot.setGraphTitle('Item No Fill') - - for legend, xList, yList, color in self.polygons: - self.plot.addShape(xList, yList, legend=legend, - replace=False, - shape="polygon", fill=False, color=color) - self.plot.resetZoom() + for scale in self.SCALES: + with self.subTest(scale=scale): + self.plot.clear() + self.plot.getXAxis().setScale(scale) + self.plot.getYAxis().setScale(scale) + self.plot.setGraphTitle('Item No Fill %s' % scale) + + for legend, xList, yList, color in self.POLYGONS: + self.plot.addShape(xList, yList, legend=legend, + replace=False, linestyle='--', + shape="polygon", fill=False, color=color) + self.plot.resetZoom() def testPlotItemRectangleFill(self): - self.plot.setGraphTitle('Rectangle Fill') - - for legend, xList, yList, color in self.rectangles: - self.plot.addShape(xList, yList, legend=legend, - replace=False, - shape="rectangle", fill=True, color=color) - self.plot.resetZoom() + for scale in self.SCALES: + with self.subTest(scale=scale): + self.plot.clear() + self.plot.getXAxis().setScale(scale) + self.plot.getYAxis().setScale(scale) + self.plot.setGraphTitle('Rectangle Fill %s' % scale) + + for legend, xList, yList, color in self.RECTANGLES: + self.plot.addShape(xList, yList, legend=legend, + replace=False, + shape="rectangle", fill=True, color=color) + self.plot.resetZoom() def testPlotItemRectangleNoFill(self): - self.plot.setGraphTitle('Rectangle No Fill') - - for legend, xList, yList, color in self.rectangles: - self.plot.addShape(xList, yList, legend=legend, - replace=False, - shape="rectangle", fill=False, color=color) - self.plot.resetZoom() + for scale in self.SCALES: + with self.subTest(scale=scale): + self.plot.clear() + self.plot.getXAxis().setScale(scale) + self.plot.getYAxis().setScale(scale) + self.plot.setGraphTitle('Rectangle No Fill %s' % scale) + + for legend, xList, yList, color in self.RECTANGLES: + self.plot.addShape(xList, yList, legend=legend, + replace=False, + shape="rectangle", fill=False, color=color) + self.plot.resetZoom() class TestPlotActiveCurveImage(PlotWidgetTestCase): @@ -1384,6 +1446,20 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase): """Test coverage on setAxesDisplayed(True)""" self.plot.setAxesDisplayed(True) + def testAxesMargins(self): + """Test PlotWidget's getAxesMargins and setAxesMargins""" + self.plot.show() + self.qWaitForWindowExposed(self.plot) + + margins = self.plot.getAxesMargins() + self.assertEqual(margins, (.15, .1, .1, .15)) + + for margins in ((0., 0., 0., 0.), (.15, .1, .1, .15)): + with self.subTest(margins=margins): + self.plot.setAxesMargins(*margins) + self.qapp.processEvents() + self.assertEqual(self.plot.getAxesMargins(), margins) + def testBoundingRectItem(self): item = BoundingRect() item.setBounds((-1000, 1000, -2000, 2000)) @@ -1752,80 +1828,33 @@ class TestPlotMarkerLog(PlotWidgetTestCase): self.plot.resetZoom() -class TestPlotItemLog(PlotWidgetTestCase): - """Basic tests for items with log scale axes""" +class TestPlotWidgetSwitchBackend(PlotWidgetTestCase): + """Test [get|set]Backend to switch backend""" - # Polygon coordinates and color - polygons = [ # legend, x coords, y coords, color - ('triangle', numpy.array((10, 30, 50)), - numpy.array((55, 70, 55)), 'red'), - ('square', numpy.array((10, 10, 50, 50)), - numpy.array((10, 50, 50, 10)), 'green'), - ('star', numpy.array((60, 70, 80, 60, 80)), - numpy.array((25, 50, 25, 40, 40)), 'blue'), - ] - - # Rectangle coordinantes and color - rectangles = [ # legend, x coords, y coords, color - ('square 1', numpy.array((1., 10.)), - numpy.array((1., 10.)), 'red'), - ('square 2', numpy.array((10., 20.)), - numpy.array((10., 20.)), 'green'), - ('square 3', numpy.array((20., 30.)), - numpy.array((20., 30.)), 'blue'), - ('rect 1', numpy.array((1., 30.)), - numpy.array((35., 40.)), 'black'), - ('line h', numpy.array((1., 30.)), - numpy.array((45., 45.)), 'darkRed'), - ] - - def setUp(self): - super(TestPlotItemLog, self).setUp() + def testSwitchBackend(self): + """Test switching a plot with a few items""" + backends = {'none': 'BackendBase', 'mpl': 'BackendMatplotlibQt'} + if test_options.WITH_GL_TEST: + backends['gl'] = 'BackendOpenGL' - self.plot.getYAxis().setLabel('Rows') - self.plot.getXAxis().setLabel('Columns') - self.plot.getXAxis().setAutoScale(False) - self.plot.getYAxis().setAutoScale(False) - self.plot.setKeepDataAspectRatio(False) - self.plot.setLimits(1., 100., 1., 100.) - self.plot.getXAxis()._setLogarithmic(True) - self.plot.getYAxis()._setLogarithmic(True) - - def testPlotItemPolygonLogFill(self): - self.plot.setGraphTitle('Item Fill Log') - - for legend, xList, yList, color in self.polygons: - self.plot.addShape(xList, yList, legend=legend, - replace=False, - shape="polygon", fill=True, color=color) - self.plot.resetZoom() - - def testPlotItemPolygonLogNoFill(self): - self.plot.setGraphTitle('Item No Fill Log') - - for legend, xList, yList, color in self.polygons: - self.plot.addShape(xList, yList, legend=legend, - replace=False, - shape="polygon", fill=False, color=color) - self.plot.resetZoom() - - def testPlotItemRectangleLogFill(self): - self.plot.setGraphTitle('Rectangle Fill Log') - - for legend, xList, yList, color in self.rectangles: - self.plot.addShape(xList, yList, legend=legend, - replace=False, - shape="rectangle", fill=True, color=color) + self.plot.addImage(numpy.arange(100).reshape(10, 10)) + self.plot.addCurve((-3, -2, -1), (1, 2, 3)) self.plot.resetZoom() + xlimits = self.plot.getXAxis().getLimits() + ylimits = self.plot.getYAxis().getLimits() + items = self.plot.getItems() + self.assertEqual(len(items), 2) - def testPlotItemRectangleLogNoFill(self): - self.plot.setGraphTitle('Rectangle No Fill Log') + for backend, className in backends.items(): + with self.subTest(backend=backend): + self.plot.setBackend(backend) + self.plot.replot() - for legend, xList, yList, color in self.rectangles: - self.plot.addShape(xList, yList, legend=legend, - replace=False, - shape="rectangle", fill=False, color=color) - self.plot.resetZoom() + retrievedBackend = self.plot.getBackend() + self.assertEqual(type(retrievedBackend).__name__, className) + self.assertEqual(self.plot.getXAxis().getLimits(), xlimits) + self.assertEqual(self.plot.getYAxis().getLimits(), ylimits) + self.assertEqual(self.plot.getItems(), items) def suite(): @@ -1841,8 +1870,7 @@ def suite(): TestPlotEmptyLog, TestPlotCurveLog, TestPlotImageLog, - TestPlotMarkerLog, - TestPlotItemLog) + TestPlotMarkerLog) test_suite = unittest.TestSuite() @@ -1859,6 +1887,9 @@ def suite(): for testClass in testClasses: test_suite.addTest(parameterize(testClass, backend='gl')) + test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase( + TestPlotWidgetSwitchBackend)) + return test_suite diff --git a/silx/gui/plot/test/testPlotWindow.py b/silx/gui/plot/test/testPlotWindow.py index 8e7b35c..e12b756 100644 --- a/silx/gui/plot/test/testPlotWindow.py +++ b/silx/gui/plot/test/testPlotWindow.py @@ -33,12 +33,12 @@ import unittest import numpy from silx.gui.utils.testutils import TestCaseQt, getQToolButtonFromAction +from silx.test.utils import test_options from silx.gui import qt from silx.gui.plot import PlotWindow from silx.gui.colors import Colormap - class TestPlotWindow(TestCaseQt): """Base class for tests of PlotWindow.""" @@ -155,6 +155,25 @@ class TestPlotWindow(TestCaseQt): self.assertEqual(self._count, 1) del self._count + @unittest.skipUnless(test_options.WITH_GL_TEST, + test_options.WITH_QT_TEST_REASON) + def testSwitchBackend(self): + """Test switching an empty plot""" + self.plot.resetZoom() + xlimits = self.plot.getXAxis().getLimits() + ylimits = self.plot.getYAxis().getLimits() + isKeepAspectRatio = self.plot.isKeepDataAspectRatio() + + for backend in ('gl', 'mpl'): + with self.subTest(): + self.plot.setBackend(backend) + self.plot.replot() + self.assertEqual(self.plot.getXAxis().getLimits(), xlimits) + self.assertEqual(self.plot.getYAxis().getLimits(), ylimits) + self.assertEqual( + self.plot.isKeepDataAspectRatio(), isKeepAspectRatio) + + def suite(): test_suite = unittest.TestSuite() test_suite.addTest( diff --git a/silx/gui/plot/test/testRoiStatsWidget.py b/silx/gui/plot/test/testRoiStatsWidget.py new file mode 100644 index 0000000..378d499 --- /dev/null +++ b/silx/gui/plot/test/testRoiStatsWidget.py @@ -0,0 +1,290 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2019 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Tests for ROIStatsWidget""" + + +from silx.gui.utils.testutils import TestCaseQt +from silx.gui import qt +from silx.gui.plot import PlotWindow +from silx.gui.plot.stats.stats import Stats +from silx.gui.plot.ROIStatsWidget import ROIStatsWidget +from silx.gui.plot.CurvesROIWidget import ROI +from silx.gui.plot.items.roi import RectangleROI, PolygonROI +from silx.gui.plot.StatsWidget import UpdateMode +import unittest +import numpy + + + +class _TestRoiStatsBase(TestCaseQt): + """Base class for several unittest relative to ROIStatsWidget""" + def setUp(self): + TestCaseQt.setUp(self) + # define plot + self.plot = PlotWindow() + self.plot.addImage(numpy.arange(10000).reshape(100, 100), + legend='img1') + self.img_item = self.plot.getImage('img1') + self.plot.addCurve(x=numpy.linspace(0, 10, 56), y=numpy.arange(56), + legend='curve1') + self.curve_item = self.plot.getCurve('curve1') + self.plot.addHistogram(edges=numpy.linspace(0, 10, 56), + histogram=numpy.arange(56), legend='histo1') + self.histogram_item = self.plot.getHistogram(legend='histo1') + self.plot.addScatter(x=numpy.linspace(0, 10, 56), + y=numpy.linspace(0, 10, 56), + value=numpy.arange(56), + legend='scatter1') + self.scatter_item = self.plot.getScatter(legend='scatter1') + + # stats widget + self.statsWidget = ROIStatsWidget(plot=self.plot) + + # define stats + stats = [ + ('sum', numpy.sum), + ('mean', numpy.mean), + ] + self.statsWidget.setStats(stats=stats) + + # define rois + self.roi1D = ROI(name='range1', fromdata=0, todata=4, type_='energy') + self.rectangle_roi = RectangleROI() + self.rectangle_roi.setGeometry(origin=(0, 0), size=(20, 20)) + self.rectangle_roi.setName('Initial ROI') + self.polygon_roi = PolygonROI() + points = numpy.array([[0, 5], [5, 0], [10, 5], [5, 10]]) + self.polygon_roi.setPoints(points) + + def statsTable(self): + return self.statsWidget._statsROITable + + def tearDown(self): + Stats._getContext.cache_clear() + self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose, True) + self.statsWidget.close() + self.statsWidget = None + self.plot.setAttribute(qt.Qt.WA_DeleteOnClose, True) + self.plot.close() + self.plot = None + TestCaseQt.tearDown(self) + + +class TestRoiStatsCouple(_TestRoiStatsBase): + """ + Test different possible couple (roi, plotItem). + Check that: + + * computation is correct if couple is valid + * raise an error if couple is invalid + """ + def testROICurve(self): + """ + Test that the couple (ROI, curveItem) can be used for stats + """ + item = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.curve_item) + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), '253') + self.assertEqual(tableItems['mean'].text(), '11.0') + + def testRectangleImage(self): + """ + Test that the couple (RectangleROI, imageItem) can be used for stats + """ + item = self.statsWidget.addItem(roi=self.rectangle_roi, + plotItem=self.img_item) + assert item is not None + self.plot.addImage(numpy.ones(10000).reshape(100, 100), + legend='img1') + self.qapp.processEvents() + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), str(float(21*21))) + self.assertEqual(tableItems['mean'].text(), '1.0') + + def testPolygonImage(self): + """ + Test that the couple (PolygonROI, imageItem) can be used for stats + """ + item = self.statsWidget.addItem(roi=self.polygon_roi, + plotItem=self.img_item) + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), '22750') + self.assertEqual(tableItems['mean'].text(), '455.0') + + def testROIImage(self): + """ + Test that the couple (ROI, imageItem) is raising an error + """ + with self.assertRaises(TypeError): + self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.img_item) + + def testRectangleCurve(self): + """ + Test that the couple (rectangleROI, curveItem) is raising an error + """ + with self.assertRaises(TypeError): + self.statsWidget.addItem(roi=self.rectangle_roi, + plotItem=self.curve_item) + + def testROIHistogram(self): + """ + Test that the couple (PolygonROI, imageItem) can be used for stats + """ + item = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.histogram_item) + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), '253') + self.assertEqual(tableItems['mean'].text(), '11.0') + + def testROIScatter(self): + """ + Test that the couple (PolygonROI, imageItem) can be used for stats + """ + item = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.scatter_item) + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), '253') + self.assertEqual(tableItems['mean'].text(), '11.0') + + +class TestRoiStatsAddRemoveItem(_TestRoiStatsBase): + """Test adding and removing (roi, plotItem) items""" + def testAddRemoveItems(self): + item1 = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.scatter_item) + self.assertTrue(item1 is not None) + self.assertEqual(self.statsTable().rowCount(), 1) + item2 = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.histogram_item) + self.assertTrue(item2 is not None) + self.assertEqual(self.statsTable().rowCount(), 2) + # try to add twice the same item + item3 = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.histogram_item) + self.assertTrue(item3 is None) + self.assertEqual(self.statsTable().rowCount(), 2) + item4 = self.statsWidget.addItem(roi=self.roi1D, + plotItem=self.curve_item) + self.assertTrue(item4 is not None) + self.assertEqual(self.statsTable().rowCount(), 3) + + self.statsWidget.removeItem(plotItem=item4._plot_item, + roi=item4._roi) + self.assertEqual(self.statsTable().rowCount(), 2) + # try to remove twice the same item + self.statsWidget.removeItem(plotItem=item4._plot_item, + roi=item4._roi) + self.assertEqual(self.statsTable().rowCount(), 2) + self.statsWidget.removeItem(plotItem=item2._plot_item, + roi=item2._roi) + self.statsWidget.removeItem(plotItem=item1._plot_item, + roi=item1._roi) + self.assertEqual(self.statsTable().rowCount(), 0) + + +class TestRoiStatsRoiUpdate(_TestRoiStatsBase): + """Test that the stats will be updated if the roi is updated""" + def testChangeRoi(self): + item = self.statsWidget.addItem(roi=self.rectangle_roi, + plotItem=self.img_item) + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), '445410') + self.assertEqual(tableItems['mean'].text(), '1010.0') + + # update roi + self.rectangle_roi.setOrigin(position=(10, 10)) + self.assertNotEqual(tableItems['sum'].text(), '445410') + self.assertNotEqual(tableItems['mean'].text(), '1010.0') + + def testUpdateModeScenario(self): + """Test update according to a simple scenario""" + self.statsWidget._setUpdateMode(UpdateMode.AUTO) + item = self.statsWidget.addItem(roi=self.rectangle_roi, + plotItem=self.img_item) + + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['sum'].text(), '445410') + self.assertEqual(tableItems['mean'].text(), '1010.0') + self.statsWidget._setUpdateMode(UpdateMode.MANUAL) + self.rectangle_roi.setOrigin(position=(10, 10)) + self.qapp.processEvents() + self.assertNotEqual(tableItems['sum'].text(), '445410') + self.assertNotEqual(tableItems['mean'].text(), '1010.0') + self.statsWidget._updateAllStats(is_request=True) + self.assertNotEqual(tableItems['sum'].text(), '445410') + self.assertNotEqual(tableItems['mean'].text(), '1010.0') + + +class TestRoiStatsPlotItemUpdate(_TestRoiStatsBase): + """Test that the stats will be updated if the plot item is updated""" + def testChangeImage(self): + self.statsWidget._setUpdateMode(UpdateMode.AUTO) + item = self.statsWidget.addItem(roi=self.rectangle_roi, + plotItem=self.img_item) + + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['mean'].text(), '1010.0') + + # update plot + self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100), + legend='img1') + self.assertNotEqual(tableItems['mean'].text(), '1059.5') + + def testUpdateModeScenario(self): + """Test update according to a simple scenario""" + self.statsWidget._setUpdateMode(UpdateMode.MANUAL) + item = self.statsWidget.addItem(roi=self.rectangle_roi, + plotItem=self.img_item) + + assert item is not None + tableItems = self.statsTable()._itemToTableItems(item) + self.assertEqual(tableItems['mean'].text(), '1010.0') + self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100), + legend='img1') + self.assertEqual(tableItems['mean'].text(), '1010.0') + self.statsWidget._updateAllStats(is_request=True) + self.assertEqual(tableItems['mean'].text(), '1110.0') + + +def suite(): + test_suite = unittest.TestSuite() + for TestClass in (TestRoiStatsCouple, TestRoiStatsRoiUpdate, + TestRoiStatsPlotItemUpdate): + test_suite.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/gui/plot/test/testScatterMaskToolsWidget.py b/silx/gui/plot/test/testScatterMaskToolsWidget.py index 171ec42..800f30e 100644 --- a/silx/gui/plot/test/testScatterMaskToolsWidget.py +++ b/silx/gui/plot/test/testScatterMaskToolsWidget.py @@ -86,10 +86,16 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase): self.mouseMove(plot, pos=(0, 0)) self.mouseMove(plot, pos=pos0) - self.mouseClick(plot, qt.Qt.LeftButton, pos=pos0) - self.mouseMove(plot, pos=(0, 0)) + self.qapp.processEvents() + self.mousePress(plot, qt.Qt.LeftButton, pos=pos0) + self.qapp.processEvents() + + self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2)) self.mouseMove(plot, pos=pos1) - self.mouseClick(plot, qt.Qt.LeftButton, pos=pos1) + self.qapp.processEvents() + self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1) + self.qapp.processEvents() + self.mouseMove(plot, pos=(0, 0)) def _drawPolygon(self): """Draw a star polygon in the plot""" @@ -108,7 +114,9 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase): for pos in star: self.mouseMove(plot, pos=pos) self.qapp.processEvents() - self.mouseClick(plot, qt.Qt.LeftButton, pos=pos) + self.mousePress(plot, qt.Qt.LeftButton, pos=pos) + self.qapp.processEvents() + self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos) self.qapp.processEvents() def _drawPencil(self): diff --git a/silx/gui/plot/test/testStackView.py b/silx/gui/plot/test/testStackView.py index 80c85d6..7605bbc 100644 --- a/silx/gui/plot/test/testStackView.py +++ b/silx/gui/plot/test/testStackView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2019 European Synchrotron Radiation Facility +# Copyright (c) 2016-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -60,6 +60,19 @@ class TestStackView(TestCaseQt): del self.stackview super(TestStackView, self).tearDown() + def testScaleColormapRangeToStack(self): + """Test scaleColormapRangeToStack""" + self.stackview.setStack(self.mystack) + self.stackview.setColormap("viridis") + colormap = self.stackview.getColormap() + + # Colormap autoscale to image + self.assertEqual(colormap.getVRange(), (None, None)) + self.stackview.scaleColormapRangeToStack() + + # Colormap range set according to stack range + self.assertEqual(colormap.getVRange(), (self.mystack.min(), self.mystack.max())) + def testSetStack(self): self.stackview.setStack(self.mystack) self.stackview.setColormap("viridis", autoscale=True) diff --git a/silx/gui/plot/test/testStats.py b/silx/gui/plot/test/testStats.py index 8db8cc9..d5046ba 100644 --- a/silx/gui/plot/test/testStats.py +++ b/silx/gui/plot/test/testStats.py @@ -35,6 +35,11 @@ from silx.gui.plot import StatsWidget from silx.gui.plot.stats import statshandler from silx.gui.utils.testutils import TestCaseQt, SignalListener from silx.gui.plot import Plot1D, Plot2D +from silx.gui.plot3d.SceneWidget import SceneWidget +from silx.gui.plot.items.roi import RectangleROI, PolygonROI +from silx.gui.plot.tools.roi import RegionOfInterestManager +from silx.gui.plot.stats.stats import Stats +from silx.gui.plot.CurvesROIWidget import ROI from silx.utils.testutils import ParametricTestCase import unittest import logging @@ -43,12 +48,9 @@ import numpy _logger = logging.getLogger(__name__) -class TestStats(TestCaseQt): - """ - Test :class:`BaseClass` class and inheriting classes - """ +class TestStatsBase(object): + """Base class for stats TestCase""" def setUp(self): - TestCaseQt.setUp(self) self.createCurveContext() self.createImageContext() self.createScatterContext() @@ -63,7 +65,6 @@ class TestStats(TestCaseQt): self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose) self.scatterPlot.close() del self.scatterPlot - TestCaseQt.tearDown(self) def createCurveContext(self): self.plot1d = Plot1D() @@ -74,12 +75,13 @@ class TestStats(TestCaseQt): self.curveContext = stats._CurveContext( item=self.plot1d.getCurve('curve0'), plot=self.plot1d, - onlimits=False) + onlimits=False, + roi=None) def createScatterContext(self): self.scatterPlot = Plot2D() lgd = 'scatter plot' - self.xScatterData = numpy.array([0, 1, 2, 20, 50, 60, 36]) + self.xScatterData = numpy.array([0, 2, 3, 20, 50, 60, 36]) self.yScatterData = numpy.array([2, 3, 4, 26, 69, 6, 18]) self.valuesScatterData = numpy.array([5, 6, 7, 10, 90, 20, 5]) self.scatterPlot.addScatter(self.xScatterData, self.yScatterData, @@ -87,7 +89,8 @@ class TestStats(TestCaseQt): self.scatterContext = stats._ScatterContext( item=self.scatterPlot.getScatter(lgd), plot=self.scatterPlot, - onlimits=False + onlimits=False, + roi=None ) def createImageContext(self): @@ -99,7 +102,8 @@ class TestStats(TestCaseQt): self.imageContext = stats._ImageContext( item=self.plot2d.getImage(self._imgLgd), plot=self.plot2d, - onlimits=False + onlimits=False, + roi=None ) def getBasicStats(self): @@ -113,6 +117,19 @@ class TestStats(TestCaseQt): 'com': stats.StatCOM() } + +class TestStats(TestStatsBase, TestCaseQt): + """ + Test :class:`BaseClass` class and inheriting classes + """ + def setUp(self): + TestCaseQt.setUp(self) + TestStatsBase.setUp(self) + + def tearDown(self): + TestStatsBase.tearDown(self) + TestCaseQt.tearDown(self) + def testBasicStatsCurve(self): """Test result for simple stats on a curve""" _stats = self.getBasicStats() @@ -155,7 +172,8 @@ class TestStats(TestCaseQt): image2Context = stats._ImageContext( item=self.plot2d.getImage(self._imgLgd), plot=self.plot2d, - onlimits=False + onlimits=False, + roi=None, ) _stats = self.getBasicStats() self.assertEqual(_stats['min'].calculate(image2Context), 0) @@ -225,21 +243,24 @@ class TestStats(TestCaseQt): curveContextOnLimits = stats._CurveContext( item=self.plot1d.getCurve('curve0'), plot=self.plot1d, - onlimits=True) + onlimits=True, + roi=None) self.assertEqual(stat.calculate(curveContextOnLimits), 2) self.plot2d.getXAxis().setLimitsConstraints(minPos=32) imageContextOnLimits = stats._ImageContext( item=self.plot2d.getImage('test image'), plot=self.plot2d, - onlimits=True) + onlimits=True, + roi=None) self.assertEqual(stat.calculate(imageContextOnLimits), 32) self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40) scatterContextOnLimits = stats._ScatterContext( item=self.scatterPlot.getScatter('scatter plot'), plot=self.scatterPlot, - onlimits=True) + onlimits=True, + roi=None) self.assertEqual(stat.calculate(scatterContextOnLimits), 20) @@ -255,7 +276,8 @@ class TestStatsFormatter(TestCaseQt): self.curveContext = stats._CurveContext( item=self.plot1d.getCurve('curve0'), plot=self.plot1d, - onlimits=False) + onlimits=False, + roi=None) self.stat = stats.StatMin() @@ -295,6 +317,7 @@ class TestStatsHandler(TestCaseQt): self.stat = stats.StatMin() def tearDown(self): + Stats._getContext.cache_clear() self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose) self.plot1d.close() self.plot1d = None @@ -391,6 +414,7 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase): self.statsTable.setStats(mystats) def tearDown(self): + Stats._getContext.cache_clear() self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) self.plot.close() self.statsTable = None @@ -493,7 +517,6 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase): self.qapp.processEvents() tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) curve0_min = tableItems['min'].text() - print(curve0_min) self.assertTrue(float(curve0_min) == -1.) self.plot.getCurve('curve0').setData(x=range(4), y=range(1, 5)) @@ -581,6 +604,7 @@ class TestStatsWidgetWithImages(TestCaseQt): self.widget.setStats(mystats) def tearDown(self): + Stats._getContext.cache_clear() self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) self.plot.close() self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) @@ -641,6 +665,7 @@ class TestStatsWidgetWithScatters(TestCaseQt): self.widget.setStats(mystats) def tearDown(self): + Stats._getContext.cache_clear() self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose) self.scatterPlot.close() self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) @@ -694,6 +719,7 @@ class TestLineWidget(TestCaseQt): stats=mystats) def tearDown(self): + Stats._getContext.cache_clear() self.qapp.processEvents() self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) self.plot.close() @@ -806,12 +832,223 @@ class TestUpdateModeWidget(TestCaseQt): self.assertEqual(manualUpdateListener.callCount(), 2) +class TestStatsROI(TestStatsBase, TestCaseQt): + """ + Test stats based on ROI + """ + def setUp(self): + TestCaseQt.setUp(self) + self.createRois() + TestStatsBase.setUp(self) + self.createHistogramContext() + + self.roiManager = RegionOfInterestManager(self.plot2d) + self.roiManager.addRoi(self._2Droi_rect) + self.roiManager.addRoi(self._2Droi_poly) + + def tearDown(self): + self.roiManager.clear() + self.roiManager = None + self._1Droi = None + self._2Droi_rect = None + self._2Droi_poly = None + self.plotHisto.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plotHisto.close() + self.plotHisto = None + TestStatsBase.tearDown(self) + TestCaseQt.tearDown(self) + + def createRois(self): + self._1Droi = ROI(name='my1DRoi', fromdata=2.0, todata=5.0) + self._2Droi_rect = RectangleROI() + self._2Droi_rect.setGeometry(size=(10, 10), origin=(10, 0)) + self._2Droi_poly = PolygonROI() + points = numpy.array(((0, 20), (0, 0), (10, 0))) + self._2Droi_poly.setPoints(points=points) + + def createCurveContext(self): + TestStatsBase.createCurveContext(self) + self.curveContext = stats._CurveContext( + item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=False, + roi=self._1Droi) + + def createHistogramContext(self): + self.plotHisto = Plot1D() + x = range(20) + y = range(20) + self.plotHisto.addHistogram(x, y, legend='histo0') + + self.histoContext = stats._HistogramContext( + item=self.plotHisto.getHistogram('histo0'), + plot=self.plotHisto, + onlimits=False, + roi=self._1Droi) + + def createScatterContext(self): + TestStatsBase.createScatterContext(self) + self.scatterContext = stats._ScatterContext( + item=self.scatterPlot.getScatter('scatter plot'), + plot=self.scatterPlot, + onlimits=False, + roi=self._1Droi + ) + + def createImageContext(self): + TestStatsBase.createImageContext(self) + + self.imageContext = stats._ImageContext( + item=self.plot2d.getImage(self._imgLgd), + plot=self.plot2d, + onlimits=False, + roi=self._2Droi_rect + ) + + self.imageContext_2 = stats._ImageContext( + item=self.plot2d.getImage(self._imgLgd), + plot=self.plot2d, + onlimits=False, + roi=self._2Droi_poly + ) + + def testErrors(self): + # test if onlimits is True and give also a roi + with self.assertRaises(ValueError): + stats._CurveContext(item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=True, + roi=self._1Droi) + + # test if is a curve context and give an invalid 2D roi + with self.assertRaises(TypeError): + stats._CurveContext(item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=False, + roi=self._2Droi_rect) + + def testBasicStatsCurve(self): + """Test result for simple stats on a curve""" + _stats = self.getBasicStats() + xData = yData = numpy.array(range(0, 10)) + self.assertEqual(_stats['min'].calculate(self.curveContext), 2) + self.assertEqual(_stats['max'].calculate(self.curveContext), 5) + self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (2,)) + self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (5,)) + self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData[2:6])) + self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData[2:6])) + com = numpy.sum(xData[2:6] * yData[2:6]) / numpy.sum(yData[2:6]) + self.assertEqual(_stats['com'].calculate(self.curveContext), com) + + def testBasicStatsImageRectRoi(self): + """Test result for simple stats on an image""" + self.assertEqual(self.imageContext.values.compressed().size, 121) + _stats = self.getBasicStats() + self.assertEqual(_stats['min'].calculate(self.imageContext), 10) + self.assertEqual(_stats['max'].calculate(self.imageContext), 1300) + self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (10, 0)) + self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (20.0, 10.0)) + self.assertAlmostEqual(_stats['std'].calculate(self.imageContext), + numpy.std(self.imageData[0:11, 10:21])) + self.assertAlmostEqual(_stats['mean'].calculate(self.imageContext), + numpy.mean(self.imageData[0:11, 10:21])) + + compressed_values = self.imageContext.values.compressed() + compressed_values = compressed_values.reshape(11, 11) + yData = numpy.sum(compressed_values.astype(numpy.float64), axis=1) + xData = numpy.sum(compressed_values.astype(numpy.float64), axis=0) + + dataYRange = range(11) + dataXRange = range(10, 21) + + ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData) + xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData) + self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom)) + + def testBasicStatsImagePolyRoi(self): + """Test a simple rectangle ROI""" + _stats = self.getBasicStats() + self.assertEqual(_stats['min'].calculate(self.imageContext_2), 0) + self.assertEqual(_stats['max'].calculate(self.imageContext_2), 2432) + self.assertEqual(_stats['minCoords'].calculate(self.imageContext_2), (0.0, 0.0)) + # not 0.0, 19.0 because not fully in. Should all pixel have a weight, + # on to manage them in stats. For now 0 if the center is not in, else 1 + self.assertEqual(_stats['maxCoords'].calculate(self.imageContext_2), (0.0, 19.0)) + + def testBasicStatsScatter(self): + self.assertEqual(self.scatterContext.values.compressed().size, 2) + _stats = self.getBasicStats() + self.assertEqual(_stats['min'].calculate(self.scatterContext), 6) + self.assertEqual(_stats['max'].calculate(self.scatterContext), 7) + self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (2, 3)) + self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (3, 4)) + self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std([6, 7])) + self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean([6, 7])) + + def testBasicHistogram(self): + _stats = self.getBasicStats() + xData = yData = numpy.array(range(2, 6)) + self.assertEqual(_stats['min'].calculate(self.histoContext), 2) + self.assertEqual(_stats['max'].calculate(self.histoContext), 5) + self.assertEqual(_stats['minCoords'].calculate(self.histoContext), (2,)) + self.assertEqual(_stats['maxCoords'].calculate(self.histoContext), (5,)) + self.assertEqual(_stats['std'].calculate(self.histoContext), numpy.std(yData)) + self.assertEqual(_stats['mean'].calculate(self.histoContext), numpy.mean(yData)) + com = numpy.sum(xData * yData) / numpy.sum(yData) + self.assertEqual(_stats['com'].calculate(self.histoContext), com) + + +class TestAdvancedROIImageContext(TestCaseQt): + """Test stats result on an image context with different scale and + origins""" + + def setUp(self): + TestCaseQt.setUp(self) + self.data_dims = (100, 100) + self.data = numpy.random.rand(*self.data_dims) + self.plot = Plot2D() + + def test(self): + """Test stats result on an image context with different scale and + origins""" + roi_origins = [(0, 0), (2, 10), (14, 20)] + img_origins = [(0, 0), (14, 20), (2, 10)] + img_scales = [1.0, 0.5, 2.0] + _stats = {'sum': stats.Stat(name='sum', fct=numpy.sum), } + for roi_origin in roi_origins: + for img_origin in img_origins: + for img_scale in img_scales: + with self.subTest(roi_origin=roi_origin, + img_origin=img_origin, + img_scale=img_scale): + self.plot.addImage(self.data, legend='img', + origin=img_origin, + scale=img_scale) + roi = RectangleROI() + roi.setGeometry(origin=roi_origin, size=(20, 20)) + context = stats._ImageContext( + item=self.plot.getImage('img'), + plot=self.plot, + onlimits=False, + roi=roi) + x_start = int((roi_origin[0] - img_origin[0]) / img_scale) + x_end = int(x_start + (20 / img_scale)) + 1 + y_start = int((roi_origin[1] - img_origin[1])/ img_scale) + y_end = int(y_start + (20 / img_scale)) + 1 + x_start = max(x_start, 0) + x_end = min(max(x_end, 0), self.data_dims[1]) + y_start = max(y_start, 0) + y_end = min(max(y_end, 0), self.data_dims[0]) + th_sum = numpy.sum(self.data[y_start:y_end, x_start:x_end]) + self.assertAlmostEqual(_stats['sum'].calculate(context), + th_sum) + def suite(): test_suite = unittest.TestSuite() for TestClass in (TestStats, TestStatsHandler, TestStatsWidgetWithScatters, TestStatsWidgetWithImages, TestStatsWidgetWithCurves, - TestStatsFormatter, TestEmptyStatsWidget, - TestLineWidget, TestUpdateModeWidget): + TestStatsFormatter, TestEmptyStatsWidget, TestStatsROI, + TestLineWidget, TestUpdateModeWidget, ): test_suite.addTest( unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) return test_suite diff --git a/silx/gui/plot/tools/profile/manager.py b/silx/gui/plot/tools/profile/manager.py index 4d467f0..757b741 100644 --- a/silx/gui/plot/tools/profile/manager.py +++ b/silx/gui/plot/tools/profile/manager.py @@ -76,6 +76,17 @@ class _RunnableComputeProfile(qt.QRunnable): self._signals.moveToThread(threadPool.thread()) self._item = item self._roi = roi + self._cancelled = False + + def _lazyCancel(self): + """Cancel the runner if it is not yet started. + + The threadpool will still execute the runner, but this will process + nothing. + + This is only used with Qt<5.9 where QThreadPool.tryTake is not available. + """ + self._cancelled = True def autoDelete(self): return False @@ -106,12 +117,13 @@ class _RunnableComputeProfile(qt.QRunnable): def run(self): """Process the profile computation. """ - try: - profileData = self._roi.computeProfile(self._item) - except Exception: - _logger.error("Error while computing profile", exc_info=True) - else: - self.resultReady.emit(self._roi, profileData) + if not self._cancelled: + try: + profileData = self._roi.computeProfile(self._item) + except Exception: + _logger.error("Error while computing profile", exc_info=True) + else: + self.resultReady.emit(self._roi, profileData) self.runnerFinished.emit(self) @@ -815,8 +827,11 @@ class ProfileManager(qt.QObject): self._pendingRunners.remove(runner) continue if runner.getRoi() is profileRoi: - if threadPool.tryTake(runner): - self._pendingRunners.remove(runner) + if hasattr(threadPool, "tryTake"): + if threadPool.tryTake(runner): + self._pendingRunners.remove(runner) + else: # Support Qt<5.9 + runner._lazyCancel() item = self.getPlotItem() if item is None or not isinstance(item, profileRoi.ITEM_KIND): diff --git a/silx/gui/plot/tools/profile/rois.py b/silx/gui/plot/tools/profile/rois.py index b49679c..9e651a7 100644 --- a/silx/gui/plot/tools/profile/rois.py +++ b/silx/gui/plot/tools/profile/rois.py @@ -137,11 +137,11 @@ class _ImageProfileArea(items.Shape): if not isinstance(item, items.ImageBase): raise TypeError("Unexpected class %s" % type(item)) - if isinstance(item, items.ImageData): - currentData = item.getData(copy=False) - elif isinstance(item, items.ImageRgba): + if isinstance(item, items.ImageRgba): rgba = item.getData(copy=False) currentData = rgba[..., 0] + else: + currentData = item.getData(copy=False) roi = self.getParentRoi() origin = item.getOrigin() @@ -310,15 +310,15 @@ class _DefaultImageProfileRoiMixIn(core.ProfileRoiMixIn): method=method) return coords, profile, profileName, xLabel - if isinstance(item, items.ImageData): - currentData = item.getData(copy=False) - elif isinstance(item, items.ImageRgba): + if isinstance(item, items.ImageRgba): rgba = item.getData(copy=False) is_uint8 = rgba.dtype.type == numpy.uint8 # luminosity if is_uint8: - rgba = rgba.astype(numpy.float) + rgba = rgba.astype(numpy.float64) currentData = 0.21 * rgba[..., 0] + 0.72 * rgba[..., 1] + 0.07 * rgba[..., 2] + else: + currentData = item.getData(copy=False) yLabel = "%s" % str(method).capitalize() coords, profile, title, xLabel = createProfile2(currentData) diff --git a/silx/gui/plot/tools/roi.py b/silx/gui/plot/tools/roi.py index 431ecb2..4e2d6db 100644 --- a/silx/gui/plot/tools/roi.py +++ b/silx/gui/plot/tools/roi.py @@ -34,10 +34,13 @@ import enum import logging import time import weakref +import functools import numpy from ... import qt, icons +from ...utils import blockSignals +from ...utils import LockReentrant from .. import PlotWidget from ..items import roi as roi_items @@ -163,6 +166,155 @@ class CreateRoiModeAction(qt.QAction): pass +class RoiModeSelector(qt.QWidget): + def __init__(self, parent=None): + super(RoiModeSelector, self).__init__(parent=parent) + self.__roi = None + self.__reentrant = LockReentrant() + + layout = qt.QHBoxLayout(self) + if isinstance(parent, qt.QMenu): + margins = layout.contentsMargins() + layout.setContentsMargins(margins.left(), 0, margins.right(), 0) + else: + layout.setContentsMargins(0, 0, 0, 0) + + self._label = qt.QLabel(self) + self._label.setText("Mode:") + self._label.setToolTip("Select a specific interaction to edit the ROI") + self._combo = qt.QComboBox(self) + self._combo.currentIndexChanged.connect(self._modeSelected) + layout.addWidget(self._label) + layout.addWidget(self._combo) + self._updateAvailableModes() + + def getRoi(self): + """Returns the edited ROI. + + :rtype: roi_items.RegionOfInterest + """ + return self.__roi + + def setRoi(self, roi): + """Returns the edited ROI. + + :rtype: roi_items.RegionOfInterest + """ + if self.__roi is roi: + return + if not isinstance(roi, roi_items.InteractionModeMixIn): + self.__roi = None + self._updateAvailableModes() + return + + if self.__roi is not None: + self.__roi.sigInteractionModeChanged.disconnect(self._modeChanged) + self.__roi = roi + if self.__roi is not None: + self.__roi.sigInteractionModeChanged.connect(self._modeChanged) + self._updateAvailableModes() + + def isEmpty(self): + return not self._label.isVisibleTo(self) + + def _updateAvailableModes(self): + roi = self.getRoi() + if isinstance(roi, roi_items.InteractionModeMixIn): + modes = roi.availableInteractionModes() + else: + modes = [] + if len(modes) <= 1: + self._label.setVisible(False) + self._combo.setVisible(False) + else: + self._label.setVisible(True) + self._combo.setVisible(True) + with blockSignals(self._combo): + self._combo.clear() + for im, m in enumerate(modes): + self._combo.addItem(m.label, m) + self._combo.setItemData(im, m.description, qt.Qt.ToolTipRole) + mode = roi.getInteractionMode() + self._modeChanged(mode) + index = modes.index(mode) + self._combo.setCurrentIndex(index) + + def _modeChanged(self, mode): + """Triggered when the ROI interaction mode was changed externally""" + if self.__reentrant.locked(): + # This event was initialised by the widget + return + roi = self.__roi + modes = roi.availableInteractionModes() + index = modes.index(mode) + with blockSignals(self._combo): + self._combo.setCurrentIndex(index) + + def _modeSelected(self): + """Triggered when the ROI interaction mode was selected in the widget""" + index = self._combo.currentIndex() + if index == -1: + return + roi = self.getRoi() + if roi is not None: + mode = self._combo.itemData(index, qt.Qt.UserRole) + with self.__reentrant: + roi.setInteractionMode(mode) + + +class RoiModeSelectorAction(qt.QWidgetAction): + """Display the selected mode of a ROI and allow to change it""" + + def __init__(self, parent=None): + super(RoiModeSelectorAction, self).__init__(parent) + self.__roiManager = None + + def createWidget(self, parent): + """Inherit the method to create a new widget""" + widget = RoiModeSelector(parent) + manager = self.__roiManager + if manager is not None: + roi = manager.getCurrentRoi() + widget.setRoi(roi) + self.setVisible(not widget.isEmpty()) + return widget + + def deleteWidget(self, widget): + """Inherit the method to delete a widget""" + widget.setRoi(None) + return qt.QWidgetAction.deleteWidget(self, widget) + + def setRoiManager(self, roiManager): + """ + Connect this action to a ROI manager. + + :param RegionOfInterestManager roiManager: A ROI manager + """ + if self.__roiManager is roiManager: + return + if self.__roiManager is not None: + self.__roiManager.sigCurrentRoiChanged.disconnect(self.__currentRoiChanged) + self.__roiManager = roiManager + if self.__roiManager is not None: + self.__roiManager.sigCurrentRoiChanged.connect(self.__currentRoiChanged) + self.__currentRoiChanged(roiManager.getCurrentRoi()) + + def __currentRoiChanged(self, roi): + """Handle changes of the selected ROI""" + self.setRoi(roi) + + def setRoi(self, roi): + """Set a profile ROI to edit. + + :param ProfileRoiMixIn roi: A profile ROI + """ + widget = None + for widget in self.createdWidgets(): + widget.setRoi(roi) + if widget is not None: + self.setVisible(not widget.isEmpty()) + + class RegionOfInterestManager(qt.QObject): """Class handling ROI interaction on a PlotWidget. @@ -257,6 +409,8 @@ class RegionOfInterestManager(qt.QObject): parent.sigItemRemoved.connect(self._itemRemoved) + parent._sigDefaultContextMenu.connect(self._feedContextMenu) + @classmethod def getSupportedRoiClasses(cls): """Returns the default available ROI classes @@ -400,25 +554,87 @@ class RegionOfInterestManager(qt.QObject): def _plotSignals(self, event): """Handle mouse interaction for ROI addition""" - if event['event'] in ('markerClicked', 'markerMoving'): + clicked = False + roi = None + if event["event"] in ("markerClicked", "markerMoving"): plot = self.parent() - legend = event['label'] + legend = event["label"] marker = plot._getMarker(legend=legend) roi = self.__getRoiFromMarker(marker) - if roi is not None and roi.isSelectable(): - self.setCurrentRoi(roi) - else: - self.setCurrentRoi(None) - elif event['event'] == 'mouseClicked' and event['button'] == 'left': + elif event["event"] == "mouseClicked" and event["button"] == "left": # Marker click is only for dnd # This also can click on a marker + clicked = True plot = self.parent() - marker = plot._getMarkerAt(event['xpixel'], event['ypixel']) + marker = plot._getMarkerAt(event["xpixel"], event["ypixel"]) roi = self.__getRoiFromMarker(marker) - if roi is not None and roi.isSelectable(): + else: + return + + if roi not in self._rois: + # The ROI is not own by this manager + return + + if roi is not None: + currentRoi = self.getCurrentRoi() + if currentRoi is roi: + if clicked: + self.__updateMode(roi) + elif roi.isSelectable(): self.setCurrentRoi(roi) + else: + self.setCurrentRoi(None) + + def __updateMode(self, roi): + if isinstance(roi, roi_items.InteractionModeMixIn): + available = roi.availableInteractionModes() + mode = roi.getInteractionMode() + imode = available.index(mode) + mode = available[(imode + 1) % len(available)] + roi.setInteractionMode(mode) + + def _feedContextMenu(self, menu): + """Called wen the default plot context menu is about to be displayed""" + roi = self.getCurrentRoi() + if roi is not None: + if roi.isEditable(): + # Filter by data position + # FIXME: It would be better to use GUI coords for it + plot = self.parent() + pos = plot.getWidgetHandle().mapFromGlobal(qt.QCursor.pos()) + data = plot.pixelToData(pos.x(), pos.y()) + if roi.contains(data): + if isinstance(roi, roi_items.InteractionModeMixIn): + self._contextMenuForInteractionMode(menu, roi) + + removeAction = qt.QAction(menu) + removeAction.setText("Remove %s" % roi.getName()) + callback = functools.partial(self.removeRoi, roi) + removeAction.triggered.connect(callback) + menu.addAction(removeAction) + + def _contextMenuForInteractionMode(self, menu, roi): + availableModes = roi.availableInteractionModes() + currentMode = roi.getInteractionMode() + submenu = qt.QMenu(menu) + modeGroup = qt.QActionGroup(menu) + modeGroup.setExclusive(True) + for mode in availableModes: + action = qt.QAction(menu) + action.setText(mode.label) + action.setToolTip(mode.description) + action.setCheckable(True) + if mode is currentMode: + action.setChecked(True) else: - self.setCurrentRoi(None) + callback = functools.partial(roi.setInteractionMode, mode) + action.triggered.connect(callback) + modeGroup.addAction(action) + submenu.addAction(action) + action = qt.QAction(menu) + action.setMenu(submenu) + action.setText("%s interaction mode" % roi.getName()) + menu.addAction(action) # RegionOfInterest API @@ -666,8 +882,9 @@ class RegionOfInterestManager(qt.QObject): if self._drawnROI is not None: # Cancel ROI create - self.removeRoi(self._drawnROI) + roi = self._drawnROI self._drawnROI = None + self.removeRoi(roi) plot = self.parent() if plot is not None: diff --git a/silx/gui/plot/tools/test/testROI.py b/silx/gui/plot/tools/test/testROI.py index 33a0000..8a00073 100644 --- a/silx/gui/plot/tools/test/testROI.py +++ b/silx/gui/plot/tools/test/testROI.py @@ -136,6 +136,31 @@ class TestRoiItems(TestCaseQt): numpy.testing.assert_allclose(item.getCenter(), center) numpy.testing.assert_allclose(item.getRadius(), newRadius) + def testCircle_contains(self): + center = numpy.array([2, -1]) + radius = 1. + item = roi_items.CircleROI() + item.setGeometry(center=center, radius=radius) + self.assertTrue(item.contains([1, -1])) + self.assertFalse(item.contains([0, 0])) + self.assertTrue(item.contains([2, 0])) + self.assertFalse(item.contains([3.01, -1])) + + def testEllipse_contains(self): + center = numpy.array([-2, 0]) + item = roi_items.EllipseROI() + item.setCenter(center) + item.setOrientation(numpy.pi / 4.0) + item.setMajorRadius(2) + item.setMinorRadius(1) + print(item.getMinorRadius(), item.getMajorRadius()) + self.assertFalse(item.contains([0, 0])) + self.assertTrue(item.contains([-1, 1])) + self.assertTrue(item.contains([-3, 0])) + self.assertTrue(item.contains([-2, 0])) + self.assertTrue(item.contains([-2, 1])) + self.assertFalse(item.contains([-4, 1])) + def testRectangle_isIn(self): origin = numpy.array([0, 0]) size = numpy.array([10, 20]) @@ -557,8 +582,9 @@ class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase): mx, my = self.plot.dataToPixel(*center) self.mouseMove(widget, pos=(mx, my)) self.mousePress(widget, qt.Qt.LeftButton, pos=(mx, my)) + self.mouseMove(widget, pos=(mx, my+25)) self.mouseMove(widget, pos=(mx, my+50)) - self.mouseRelease(widget, qt.Qt.LeftButton, pos=(mx, my)) + self.mouseRelease(widget, qt.Qt.LeftButton, pos=(mx, my+50)) result = numpy.array(item.getEndPoints()) # x location is still the same @@ -615,6 +641,45 @@ class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase): # Clean up manager.clear() + def testArcRoiSwitchMode(self): + """Make sure we can switch mode by clicking on the ROI""" + xlimit = self.plot.getXAxis().getLimits() + ylimit = self.plot.getYAxis().getLimits() + points = numpy.array([xlimit, ylimit]).T + center = numpy.mean(points, axis=0) + size = numpy.abs(points[1] - points[0]) + + # Create the line + manager = roi.RegionOfInterestManager(self.plot) + item = roi_items.ArcROI() + item.setGeometry(center, size[1] / 10, size[1] / 2, 0, 3) + item.setEditable(True) + item.setSelectable(True) + manager.addRoi(item) + self.qapp.processEvents() + + # Initial state + self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode) + self.qWait(500) + + # Click on the center + widget = self.plot.getWidgetHandle() + mx, my = self.plot.dataToPixel(*center) + + # Select the ROI + self.mouseMove(widget, pos=(mx, my)) + self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my)) + self.qWait(500) + self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode) + + # Change the mode + self.mouseMove(widget, pos=(mx, my)) + self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my)) + self.qWait(500) + self.assertIs(item.getInteractionMode(), roi_items.ArcROI.PolarMode) + + manager.clear() + self.qapp.processEvents() def suite(): diff --git a/silx/gui/plot3d/ScalarFieldView.py b/silx/gui/plot3d/ScalarFieldView.py index 50cba05..b2bb254 100644 --- a/silx/gui/plot3d/ScalarFieldView.py +++ b/silx/gui/plot3d/ScalarFieldView.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# Copyright (c) 2015-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -239,7 +239,7 @@ class SelectedRegion(object): def __init__(self, arrayRange, dataBBox, translation=(0., 0., 0.), scale=(1., 1., 1.)): - self._arrayRange = numpy.array(arrayRange, copy=True, dtype=numpy.int) + self._arrayRange = numpy.array(arrayRange, copy=True, dtype=numpy.int64) assert self._arrayRange.shape == (3, 2) assert numpy.all(self._arrayRange[:, 1] >= self._arrayRange[:, 0]) @@ -1449,7 +1449,7 @@ class ScalarFieldView(Plot3DWindow): min(self._data.shape[1], max(*yrange))), (max(0, min(*xrange_)), min(self._data.shape[2], max(*xrange_))), - ), dtype=numpy.int) + ), dtype=numpy.int64) # numpy.equal supports None if not numpy.all(numpy.equal(selectedRange, self._selectedRange)): diff --git a/silx/gui/plot3d/items/_pick.py b/silx/gui/plot3d/items/_pick.py index 8494723..0d6a495 100644 --- a/silx/gui/plot3d/items/_pick.py +++ b/silx/gui/plot3d/items/_pick.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2018-2019 European Synchrotron Radiation Facility +# Copyright (c) 2018-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -197,7 +197,7 @@ class PickingResult(_PickingResult): super(PickingResult, self).__init__(item, indices) self._objectPositions = numpy.array( - positions, copy=False, dtype=numpy.float) + positions, copy=False, dtype=numpy.float64) # Store matrices to generate positions on demand primitive = item._getScenePrimitive() diff --git a/silx/gui/plot3d/items/core.py b/silx/gui/plot3d/items/core.py index 1745b2b..ab2ceb6 100644 --- a/silx/gui/plot3d/items/core.py +++ b/silx/gui/plot3d/items/core.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2017-2018 European Synchrotron Radiation Facility +# Copyright (c) 2017-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -400,32 +400,32 @@ class DataItem3D(Item3D): self._updated(Item3DChangedType.TRANSFORM) def setRotationCenter(self, x=0., y=0., z=0.): - """Set the center of rotation of the item. - - Position of the rotation center is either a float - for an absolute position or one of the following - string to define a position relative to the item's bounding box: - 'lower', 'center', 'upper' - - :param x: rotation center position on the X axis - :rtype: float or str - :param y: rotation center position on the Y axis - :rtype: float or str - :param z: rotation center position on the Z axis - :rtype: float or str - """ - center = [] - for position in (x, y, z): - if isinstance(position, six.string_types): - assert position in self._ROTATION_CENTER_TAGS - else: - position = float(position) - center.append(position) - center = tuple(center) - - if center != self._rotationCenter: - self._rotationCenter = center - self._updateRotationCenter() + """Set the center of rotation of the item. + + Position of the rotation center is either a float + for an absolute position or one of the following + string to define a position relative to the item's bounding box: + 'lower', 'center', 'upper' + + :param x: rotation center position on the X axis + :rtype: float or str + :param y: rotation center position on the Y axis + :rtype: float or str + :param z: rotation center position on the Z axis + :rtype: float or str + """ + center = [] + for position in (x, y, z): + if isinstance(position, six.string_types): + assert position in self._ROTATION_CENTER_TAGS + else: + position = float(position) + center.append(position) + center = tuple(center) + + if center != self._rotationCenter: + self._rotationCenter = center + self._updateRotationCenter() def getRotationCenter(self): """Returns the rotation center set by :meth:`setRotationCenter`. diff --git a/silx/gui/plot3d/items/mixins.py b/silx/gui/plot3d/items/mixins.py index 14cafc8..f512365 100644 --- a/silx/gui/plot3d/items/mixins.py +++ b/silx/gui/plot3d/items/mixins.py @@ -141,6 +141,7 @@ class ColormapMixIn(_ColormapMixIn): self.__sceneColormap.norm = colormap.getNormalization() self.__sceneColormap.gamma = colormap.getGammaNormalizationParameter() self.__sceneColormap.range_ = colormap.getColormapRange(self) + self.__sceneColormap.nancolor = rgba(colormap.getNaNColor()) class ComplexMixIn(_ComplexMixIn): diff --git a/silx/gui/plot3d/items/volume.py b/silx/gui/plot3d/items/volume.py index 6c6562f..f80fea2 100644 --- a/silx/gui/plot3d/items/volume.py +++ b/silx/gui/plot3d/items/volume.py @@ -444,7 +444,7 @@ class Isosurface(Item3D): return None # No intersected triangles intersections = numpy.array(intersections)[numpy.argsort(depths)] - indices = numpy.transpose(numpy.round(intersections).astype(numpy.int)) + indices = numpy.transpose(numpy.round(intersections).astype(numpy.int64)) return PickingResult(self, positions=intersections, indices=indices) diff --git a/silx/gui/plot3d/scene/cutplane.py b/silx/gui/plot3d/scene/cutplane.py index 81c74c7..88147df 100644 --- a/silx/gui/plot3d/scene/cutplane.py +++ b/silx/gui/plot3d/scene/cutplane.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2019 European Synchrotron Radiation Facility +# Copyright (c) 2016-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -88,7 +88,7 @@ class ColormapMesh3D(Geometry): float value = texture3D(data, vTexCoords).r; vec4 color = $colormapCall(value); - color.a = alpha; + color.a *= alpha; gl_FragColor = $lightingCall(color, vPosition, vNormal); diff --git a/silx/gui/plot3d/scene/function.py b/silx/gui/plot3d/scene/function.py index 69a24dd..2deb785 100644 --- a/silx/gui/plot3d/scene/function.py +++ b/silx/gui/plot3d/scene/function.py @@ -389,10 +389,13 @@ class Colormap(event.Notifier, ProgramFunction): uniform float cmap_parameter; uniform float cmap_min; uniform float cmap_oneOverRange; + uniform vec4 nancolor; const float oneOverLog10 = 0.43429448190325176; vec4 colormap(float value) { + float data = value; /* Keep original input value for isnan test */ + if (cmap_normalization == 1) { /* Log10 mapping */ if (value > 0.0) { value = clamp(cmap_oneOverRange * @@ -421,7 +424,12 @@ class Colormap(event.Notifier, ProgramFunction): $discard - vec4 color = texture2D(cmap_texture, vec2(value, 0.5)); + vec4 color; + if (data != data) { /* isnan alternative for compatibility with GLSL 1.20 */ + color = nancolor; + } else { + color = texture2D(cmap_texture, vec2(value, 0.5)); + } return color; } """) @@ -458,9 +466,10 @@ class Colormap(event.Notifier, ProgramFunction): self._gamma = -1. self._range = 1., 10. self._displayValuesBelowMin = True + self._nancolor = numpy.array((1., 1., 1., 0.), dtype=numpy.float32) self._texture = None - self._update_texture = True + self._textureToDiscard = None if colormap is None: # default colormap @@ -468,7 +477,7 @@ class Colormap(event.Notifier, ProgramFunction): colormap[:] = numpy.arange(256, dtype=numpy.uint8)[:, numpy.newaxis] - # Set to param values through properties to go through asserts + # Set to values through properties to perform asserts and updates self.colormap = colormap self.norm = norm self.gamma = gamma @@ -491,9 +500,40 @@ class Colormap(event.Notifier, ProgramFunction): assert colormap.ndim == 2 assert colormap.shape[1] in (3, 4) self._colormap = colormap - self._update_texture = True + + if self._texture is not None and self._texture.name is not None: + self._textureToDiscard = self._texture + + data = numpy.empty( + (16, self._colormap.shape[0], self._colormap.shape[1]), + dtype=self._colormap.dtype) + data[:] = self._colormap + + format_ = gl.GL_RGBA if data.shape[-1] == 4 else gl.GL_RGB + + self._texture = _glutils.Texture( + format_, data, format_, + texUnit=self._COLORMAP_TEXTURE_UNIT, + minFilter=gl.GL_NEAREST, + magFilter=gl.GL_NEAREST, + wrap=gl.GL_CLAMP_TO_EDGE) + self.notify() + @property + def nancolor(self): + """RGBA color to use for Not-A-Number values as 4 float in [0., 1.]""" + return self._nancolor + + @nancolor.setter + def nancolor(self, color): + color = numpy.clip(numpy.array(color, dtype=numpy.float32), 0., 1.) + assert color.ndim == 1 + assert len(color) == 4 + if not numpy.array_equal(self._nancolor, color): + self._nancolor = color + self.notify() + @property def norm(self): """Normalization to use for colormap mapping. @@ -576,9 +616,6 @@ class Colormap(event.Notifier, ProgramFunction): """ self.prepareGL2(context) # TODO see how to handle - if self._texture is None: # No colormap - return - self._texture.bind() gl.glUniform1i(program.uniforms['cmap_texture'], @@ -607,23 +644,11 @@ class Colormap(event.Notifier, ProgramFunction): gl.glUniform1f(program.uniforms['cmap_min'], min_) gl.glUniform1f(program.uniforms['cmap_oneOverRange'], (1. / (max_ - min_)) if max_ != min_ else 0.) + gl.glUniform4f(program.uniforms['nancolor'], *self._nancolor) def prepareGL2(self, context): - if self._texture is None or self._update_texture: - if self._texture is not None: - self._texture.discard() - - colormap = numpy.empty( - (16, self._colormap.shape[0], self._colormap.shape[1]), - dtype=self._colormap.dtype) - colormap[:] = self._colormap - - format_ = gl.GL_RGBA if colormap.shape[-1] == 4 else gl.GL_RGB - - self._texture = _glutils.Texture( - format_, colormap, format_, - texUnit=self._COLORMAP_TEXTURE_UNIT, - minFilter=gl.GL_NEAREST, - magFilter=gl.GL_NEAREST, - wrap=gl.GL_CLAMP_TO_EDGE) - self._update_texture = False + if self._textureToDiscard is not None: + self._textureToDiscard.discard() + self._textureToDiscard = None + + self._texture.prepare() diff --git a/silx/gui/plot3d/scene/primitives.py b/silx/gui/plot3d/scene/primitives.py index 7db61e8..b4c8e26 100644 --- a/silx/gui/plot3d/scene/primitives.py +++ b/silx/gui/plot3d/scene/primitives.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2019 European Synchrotron Radiation Facility +# Copyright (c) 2015-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -49,7 +49,7 @@ from . import event from . import core from . import transform from . import utils -from .function import Colormap, Fog +from .function import Colormap _logger = logging.getLogger(__name__) @@ -367,7 +367,7 @@ class Geometry(core.Elem): min_ = numpy.nanmin(attribute, axis=0) max_ = numpy.nanmax(attribute, axis=0) else: - min_, max_ = numpy.zeros((2, attribute.shape[1]), dtype=numpy.float32) + min_, max_ = numpy.zeros((2, attribute.shape[1]), dtype=numpy.float32) toCopy = min(len(min_), 3-index) if toCopy != len(min_): @@ -2077,7 +2077,7 @@ class _Image(Geometry): self._update_texture = True # By updating the position rather than always using a unit square # we benefit from Geometry bounds handling - self.setAttribute('position', self._UNIT_SQUARE * self._data.shape[:2]) + self.setAttribute('position', self._UNIT_SQUARE * (self._data.shape[1], self._data.shape[0])) self.notify() def getData(self, copy=True): @@ -2188,7 +2188,7 @@ class _Image(Geometry): gl.glUniform1f(program.uniforms['alpha'], self._alpha) shape = self._data.shape - gl.glUniform2f(program.uniforms['dataScale'], 1./shape[0], 1./shape[1]) + gl.glUniform2f(program.uniforms['dataScale'], 1./shape[1], 1./shape[0]) gl.glUniform1i(program.uniforms['data'], self._texture.texUnit) diff --git a/silx/gui/plot3d/scene/text.py b/silx/gui/plot3d/scene/text.py index c2983d5..bacc2e6 100644 --- a/silx/gui/plot3d/scene/text.py +++ b/silx/gui/plot3d/scene/text.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# Copyright (c) 2016-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -251,6 +251,7 @@ class Text2D(primitives.Geometry): minFilter=gl.GL_NEAREST, magFilter=gl.GL_NEAREST, wrap=gl.GL_CLAMP_TO_EDGE) + self._texture.prepare() self._dirtyAlign = True # To force update of offset if self._dirtyAlign: diff --git a/silx/gui/plot3d/scene/transform.py b/silx/gui/plot3d/scene/transform.py index 1b82397..43b739b 100644 --- a/silx/gui/plot3d/scene/transform.py +++ b/silx/gui/plot3d/scene/transform.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2018 European Synchrotron Radiation Facility +# Copyright (c) 2015-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -855,13 +855,13 @@ class _Projection(Transform): class Orthographic(_Projection): - """Orthographic (i.e., parallel) projection which keeps aspect ratio. + """Orthographic (i.e., parallel) projection which can keep aspect ratio. Clipping planes are adjusted to match the aspect ratio of - the :attr:`size` attribute. + the :attr:`size` attribute if :attr:`keepaspect` is True. - The left, right, bottom and top parameters defines the area which must - always remain visible. + In this case, the left, right, bottom and top parameters defines the area + which must always remain visible. Effective clipping planes are adjusted to keep the aspect ratio. :param float left: Coord of the left clipping plane. @@ -873,12 +873,15 @@ class Orthographic(_Projection): :param size: Viewport's size used to compute the aspect ratio (width, height). :type size: 2-tuple of float + :param bool keepaspect: + True (default) to keep aspect ratio, False otherwise. """ def __init__(self, left=0., right=1., bottom=1., top=0., near=-1., far=1., - size=(1., 1.)): + size=(1., 1.), keepaspect=True): self._left, self._right = left, right self._bottom, self._top = bottom, top + self._keepaspect = bool(keepaspect) super(Orthographic, self).__init__(near, far, checkDepthExtent=False, size=size) # _update called when setting size @@ -888,22 +891,23 @@ class Orthographic(_Projection): self.left, self.right, self.bottom, self.top, self.near, self.far) def _update(self, left, right, bottom, top): - width, height = self.size - aspect = width / height + if self.keepaspect: + width, height = self.size + aspect = width / height - orthoaspect = abs(left - right) / abs(bottom - top) + orthoaspect = abs(left - right) / abs(bottom - top) - if orthoaspect >= aspect: # Keep width, enlarge height - newheight = \ - numpy.sign(top - bottom) * abs(left - right) / aspect - bottom = 0.5 * (bottom + top) - 0.5 * newheight - top = bottom + newheight + if orthoaspect >= aspect: # Keep width, enlarge height + newheight = \ + numpy.sign(top - bottom) * abs(left - right) / aspect + bottom = 0.5 * (bottom + top) - 0.5 * newheight + top = bottom + newheight - else: # Keep height, enlarge width - newwidth = \ - numpy.sign(right - left) * abs(bottom - top) * aspect - left = 0.5 * (left + right) - 0.5 * newwidth - right = left + newwidth + else: # Keep height, enlarge width + newwidth = \ + numpy.sign(right - left) * abs(bottom - top) * aspect + left = 0.5 * (left + right) - 0.5 * newwidth + right = left + newwidth # Store values self._left, self._right = left, right @@ -942,15 +946,30 @@ class Orthographic(_Projection): @property def size(self): - """Viewport size as a 2-tuple of float (width, height) or None.""" + """Viewport size as a 2-tuple of float (width, height)""" return self._size @size.setter def size(self, size): assert len(size) == 2 - self._size = float(size[0]), float(size[1]) - self._update(self.left, self.right, self.bottom, self.top) - self.notify() + size = float(size[0]), float(size[1]) + if size != self._size: + self._size = size + self._update(self.left, self.right, self.bottom, self.top) + self.notify() + + @property + def keepaspect(self): + """True to keep aspect ratio, False otherwise.""" + return self._keepaspect + + @keepaspect.setter + def keepaspect(self, aspect): + aspect = bool(aspect) + if aspect != self._keepaspect: + self._keepaspect = aspect + self._update(self.left, self.right, self.bottom, self.top) + self.notify() class Ortho2DWidget(_Projection): diff --git a/silx/gui/plot3d/scene/utils.py b/silx/gui/plot3d/scene/utils.py index bddbcac..c6cd129 100644 --- a/silx/gui/plot3d/scene/utils.py +++ b/silx/gui/plot3d/scene/utils.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2015-2019 European Synchrotron Radiation Facility +# Copyright (c) 2015-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -540,7 +540,7 @@ def segmentVolumeIntersect(segment, nbins): # bin edges/line intersection points points = t.reshape(-1, 1) * delta + p0 centers = (points[:-1] + points[1:]) / 2. - bins = numpy.floor(centers).astype(numpy.int) + bins = numpy.floor(centers).astype(numpy.int64) return bins diff --git a/silx/gui/plot3d/test/testStatsWidget.py b/silx/gui/plot3d/test/testStatsWidget.py index 1157aec..bcab1a4 100644 --- a/silx/gui/plot3d/test/testStatsWidget.py +++ b/silx/gui/plot3d/test/testStatsWidget.py @@ -34,6 +34,7 @@ import numpy from silx.utils.testutils import ParametricTestCase from silx.gui.utils.testutils import TestCaseQt +from silx.gui.plot.stats.stats import Stats from silx.gui import qt from silx.gui.plot.StatsWidget import BasicStatsWidget @@ -55,6 +56,7 @@ class TestSceneWidget(TestCaseQt, ParametricTestCase): # self.qWaitForWindowExposed(self.sceneWidget) def tearDown(self): + Stats._getContext.cache_clear() self.qapp.processEvents() self.sceneWidget.setAttribute(qt.Qt.WA_DeleteOnClose) self.sceneWidget.close() @@ -147,6 +149,7 @@ class TestScalarFieldView(TestCaseQt): # self.qWaitForWindowExposed(self.sceneWidget) def tearDown(self): + Stats._getContext.cache_clear() self.qapp.processEvents() self.scalarFieldView.setAttribute(qt.Qt.WA_DeleteOnClose) self.scalarFieldView.close() diff --git a/silx/gui/test/test_colors.py b/silx/gui/test/test_colors.py index f83ff58..9e23a93 100755 --- a/silx/gui/test/test_colors.py +++ b/silx/gui/test/test_colors.py @@ -113,6 +113,20 @@ class TestApplyColormapToData(ParametricTestCase): self.assertEqual(len(value), 1) self.assertEqual(value[0, 0], 128) + def testNaNColor(self): + """Test Colormap.applyToData with NaN values""" + colormap = Colormap(name='gray', normalization='linear') + colormap.setNaNColor('red') + self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0)) + + data = numpy.array([50., numpy.nan]) + image = items.ImageData() + image.setData(numpy.array([[0, 100]])) + value = colormap.applyToData(data, reference=image) + self.assertEqual(len(value), 2) + self.assertTrue(numpy.array_equal(value[0], (128, 128, 128, 255))) + self.assertTrue(numpy.array_equal(value[1], (255, 0, 0, 255))) + class TestDictAPI(unittest.TestCase): """Make sure the old dictionary API is working @@ -436,9 +450,10 @@ class TestObjectAPI(ParametricTestCase): Colormap(name="viridis"), Colormap(normalization=Colormap.SQRT) ] - gamma = Colormap(normalization=Colormap.GAMMA) - gamma.setGammaNormalizationParameter(1.2) - colormaps.append(gamma) + cmap = Colormap(normalization=Colormap.GAMMA) + cmap.setGammaNormalizationParameter(1.2) + cmap.setNaNColor('red') + colormaps.append(cmap) for expected in colormaps: with self.subTest(colormap=expected): state = expected.saveState() @@ -459,6 +474,21 @@ class TestObjectAPI(ParametricTestCase): expected = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM) self.assertEqual(colormap, expected) + def testStorageV2(self): + state = b'\x00\x00\x00\x10\x00C\x00o\x00l\x00o\x00r\x00m\x00a\x00p\x00'\ + b'\x00\x00\x02\x00\x00\x00\x0e\x00v\x00i\x00r\x00i\x00d\x00i\x00'\ + b's\x00\x00\x00\x00\x06\x00?\xf0\x00\x00\x00\x00\x00\x00\x00\x00'\ + b'\x00\x00\x06\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06'\ + b'\x00l\x00o\x00g\x00\x00\x00\x0c\x00m\x00i\x00n\x00m\x00a\x00x' + state = qt.QByteArray(state) + colormap = Colormap() + colormap.restoreState(state) + + expected = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM) + expected.setGammaNormalizationParameter(1.5) + self.assertEqual(colormap, expected) + + class TestPreferredColormaps(unittest.TestCase): """Test get|setPreferredColormaps functions""" @@ -540,20 +570,25 @@ class TestAutoscaleRange(ParametricTestCase): def testAutoscaleRange(self): nan = numpy.nan + data_std_inside = numpy.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2]) + data_std_inside_nan = numpy.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, numpy.nan]) data = [ # Positive values (Colormap.LINEAR, Colormap.MINMAX, numpy.array([10, 20, 50]), (10, 50)), (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100]), (10, 100)), - (Colormap.LINEAR, Colormap.STDDEV3, numpy.array([10, 100]), (-80, 190)), - (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100]), (1, 1000)), + (Colormap.LINEAR, Colormap.STDDEV3, data_std_inside, (0.026671473215424735, 1.9733285267845753)), + (Colormap.LOGARITHM, Colormap.STDDEV3, data_std_inside, (1, 1.6733506885453602)), + (Colormap.LINEAR, Colormap.STDDEV3, numpy.array([10, 100]), (10, 100)), + (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100]), (10, 100)), + # With nan (Colormap.LINEAR, Colormap.MINMAX, numpy.array([10, 20, 50, nan]), (10, 50)), (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100, nan]), (10, 100)), - (Colormap.LINEAR, Colormap.STDDEV3, numpy.array([10, 100, nan]), (-80, 190)), - (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100, nan]), (1, 1000)), + (Colormap.LINEAR, Colormap.STDDEV3, data_std_inside_nan, (0.026671473215424735, 1.9733285267845753)), + (Colormap.LOGARITHM, Colormap.STDDEV3, data_std_inside_nan, (1, 1.6733506885453602)), # With negative (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100, -50]), (10, 100)), - (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100, -10]), (1, 1000)), + (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100, -10]), (10, 100)), ] for norm, mode, array, expectedRange in data: with self.subTest(norm=norm, mode=mode, array=array): diff --git a/silx/gui/utils/glutils.py b/silx/gui/utils/glutils.py index fca9a32..83cfd89 100644 --- a/silx/gui/utils/glutils.py +++ b/silx/gui/utils/glutils.py @@ -27,6 +27,13 @@ import os import sys + +if __name__ == "__main__": + # When run as a script, remove directory from sys.path + # This avoids other script in same directory to override Python modules + if os.path.abspath(sys.path[0]) == os.path.abspath(os.path.dirname(__file__)): + sys.path.pop(0) + import subprocess from silx.gui import qt diff --git a/silx/gui/utils/matplotlib.py b/silx/gui/utils/matplotlib.py new file mode 100644 index 0000000..484e01a --- /dev/null +++ b/silx/gui/utils/matplotlib.py @@ -0,0 +1,71 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +from __future__ import absolute_import + +"""This module initializes matplotlib and sets-up the backend to use. + +It MUST be imported prior to any other import of matplotlib. + +It provides the matplotlib :class:`FigureCanvasQTAgg` class corresponding +to the used backend. +""" + +__authors__ = ["T. Vincent"] +__license__ = "MIT" +__date__ = "02/05/2018" + + +from pkg_resources import parse_version +import matplotlib + +from .. import qt + + +def _matplotlib_use(backend, force): + """Wrapper of `matplotlib.use` to set-up backend. + + It adds extra initialization for PySide and PySide2 with matplotlib < 2.2. + """ + # This is kept for compatibility with matplotlib < 2.2 + if parse_version(matplotlib.__version__) < parse_version('2.2'): + if qt.BINDING == 'PySide': + matplotlib.rcParams['backend.qt4'] = 'PySide' + if qt.BINDING == 'PySide2': + matplotlib.rcParams['backend.qt5'] = 'PySide2' + + matplotlib.use(backend, force=force) + + +if qt.BINDING in ('PyQt4', 'PySide'): + _matplotlib_use('Qt4Agg', force=False) + from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg # noqa + +elif qt.BINDING in ('PyQt5', 'PySide2'): + _matplotlib_use('Qt5Agg', force=False) + from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg # noqa + +else: + raise ImportError("Unsupported Qt binding: %s" % qt.BINDING) diff --git a/silx/gui/utils/signal.py b/silx/gui/utils/signal.py new file mode 100644 index 0000000..359f5cc --- /dev/null +++ b/silx/gui/utils/signal.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2012 University of North Carolina at Chapel Hill, Luke Campagnola +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""This module contains utils relative to qt Signal +""" + +from silx.gui import qt +import weakref +from time import time +from silx.gui.utils import concurrent + +__all__ = ['SignalProxy'] +__authors__ = ['L. Campagnola', 'M. Liberty'] +__license__ = "MIT" + + +class SignalProxy(qt.QObject): + """ + This peace of code come from pyqtgraph + Object which collects rapid-fire signals and condenses them + into a single signal or a rate-limited stream of signals. + Used, for example, to prevent a SpinBox from generating multiple + signals when the mouse wheel is rolled over it. + + Emits sigDelayed after input signals have stopped for a certain period of time. + """ + + sigDelayed = qt.Signal(object) + + def __init__(self, signal, delay=0.3, rateLimit=0, slot=None): + """Initialization arguments: + signal - a bound Signal or pyqtSignal instance + delay - Time (in seconds) to wait for signals to stop before emitting (default 0.3s) + slot - Optional function to connect sigDelayed to. + rateLimit - (signals/sec) if greater than 0, this allows signals to stream out at a + steady rate while they are being received. + """ + + qt.QObject.__init__(self) + signal.connect(self.signalReceived) + self.signal = signal + self.delay = delay + self.rateLimit = rateLimit + self.args = None + self.timer = qt.QTimer() + self.timer.timeout.connect(self.flush) + self.blockSignal = False + self.slot = weakref.ref(slot) + self.lastFlushTime = None + if slot is not None: + self.sigDelayed.connect(slot) + + def setDelay(self, delay): + self.delay = delay + + def signalReceived(self, *args): + """Received signal. Cancel previous timer and store args to be forwarded later.""" + if self.blockSignal: + return + self.args = args + if self.rateLimit == 0: + concurrent.submitToQtMainThread(self.timer.stop) + concurrent.submitToQtMainThread(self.timer.start, (self.delay * 1000) + 1) + else: + now = time() + if self.lastFlushTime is None: + leakTime = 0 + else: + lastFlush = self.lastFlushTime + leakTime = max(0, (lastFlush + (1.0 / self.rateLimit)) - now) + + concurrent.submitToQtMainThread(self.timer.stop) + concurrent.submitToQtMainThread(self.timer.start, (min(leakTime, self.delay) * 1000) + 1) + # self.timer.stop() + # self.timer.start((min(leakTime, self.delay) * 1000) + 1) + + def flush(self): + """If there is a signal queued up, send it now.""" + if self.args is None or self.blockSignal: + return False + args, self.args = self.args, None + concurrent.submitToQtMainThread(self.timer.stop) + self.lastFlushTime = time() + # self.emit(self.signal, *self.args) + concurrent.submitToQtMainThread(self.sigDelayed.emit, args) + # self.sigDelayed.emit(args) + return True + + def disconnect(self): + self.blockSignal = True + try: + self.signal.disconnect(self.signalReceived) + except: + pass + try: + self.sigDelayed.disconnect(self.slot) + except: + pass + + +if __name__ == '__main__': + app = qt.QApplication([]) + win = qt.QMainWindow() + spin = qt.QSpinBox() + win.setCentralWidget(spin) + win.show() + + + def fn(*args): + print("Raw signal:", args) + + + def fn2(*args): + print("Delayed signal:", args) + + + spin.valueChanged.connect(fn) + # proxy = proxyConnect(spin, QtCore.SIGNAL('valueChanged(int)'), fn) + proxy = SignalProxy(spin.valueChanged, delay=0.5, slot=fn2) diff --git a/silx/gui/utils/testutils.py b/silx/gui/utils/testutils.py index c086657..30b9e34 100644 --- a/silx/gui/utils/testutils.py +++ b/silx/gui/utils/testutils.py @@ -142,8 +142,6 @@ class TestCaseQt(unittest.TestCase): @classmethod def tearDownClass(cls): sys.excepthook = cls._oldExceptionHook - if cls._qapp is not None: - cls._qapp = None def setUp(self): """Get the list of existing widgets.""" diff --git a/silx/gui/widgets/ElidedLabel.py b/silx/gui/widgets/ElidedLabel.py index 58513c7..fe53bb9 100644 --- a/silx/gui/widgets/ElidedLabel.py +++ b/silx/gui/widgets/ElidedLabel.py @@ -61,12 +61,12 @@ class ElidedLabel(qt.QLabel): self.__updateText() def __updateMinimumSize(self): - metrics = qt.QFontMetrics(self.font()) + metrics = self.fontMetrics() width = metrics.width("...") self.setMinimumWidth(width) def __updateText(self): - metrics = qt.QFontMetrics(self.font()) + metrics = self.fontMetrics() elidedText = metrics.elidedText(self.__text, self.__elideMode, self.width()) qt.QLabel.setText(self, elidedText) wasElided = self.__textIsElided diff --git a/silx/gui/widgets/test/__init__.py b/silx/gui/widgets/test/__init__.py index b868171..9aaec76 100644 --- a/silx/gui/widgets/test/__init__.py +++ b/silx/gui/widgets/test/__init__.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016-2018 European Synchrotron Radiation Facility +# Copyright (c) 2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -34,6 +34,7 @@ from . import test_boxlayoutdockwidget from . import test_rangeslider from . import test_flowlayout from . import test_elidedlabel +from . import test_legendiconwidget __authors__ = ["V. Valls", "P. Knobel"] __license__ = "MIT" @@ -53,5 +54,6 @@ def suite(): test_rangeslider.suite(), test_flowlayout.suite(), test_elidedlabel.suite(), + test_legendiconwidget.suite(), ]) return test_suite diff --git a/silx/gui/widgets/test/test_legendiconwidget.py b/silx/gui/widgets/test/test_legendiconwidget.py new file mode 100644 index 0000000..f845f75 --- /dev/null +++ b/silx/gui/widgets/test/test_legendiconwidget.py @@ -0,0 +1,74 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2020 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ +"""Tests for LegendIconWidget""" + +__authors__ = ["V. Valls"] +__license__ = "MIT" +__date__ = "23/10/2020" + +import unittest + +from silx.gui import qt +from silx.gui.widgets.LegendIconWidget import LegendIconWidget +from silx.gui.utils.testutils import TestCaseQt +from silx.utils.testutils import ParametricTestCase + + +class TestLegendIconWidget(TestCaseQt, ParametricTestCase): + """Tests for TestRangeSlider""" + + def setUp(self): + self.widget = LegendIconWidget() + self.widget.show() + self.qWaitForWindowExposed(self.widget) + + def tearDown(self): + self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) + self.widget.close() + del self.widget + self.qapp.processEvents() + + def testCreate(self): + self.qapp.processEvents() + + def testColormap(self): + self.widget.setColormap("viridis") + self.qapp.processEvents() + + def testSymbol(self): + self.widget.setSymbol("o") + self.widget.setSymbolColormap("viridis") + self.qapp.processEvents() + + +def suite(): + loader = unittest.defaultTestLoader.loadTestsFromTestCase + test_suite = unittest.TestSuite() + test_suite.addTest(loader(TestLegendIconWidget)) + return test_suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') diff --git a/silx/image/marchingsquares/_mergeimpl.pyx b/silx/image/marchingsquares/_mergeimpl.pyx index 7286a66..5a7a3b5 100644 --- a/silx/image/marchingsquares/_mergeimpl.pyx +++ b/silx/image/marchingsquares/_mergeimpl.pyx @@ -1,6 +1,6 @@ # coding: utf-8 # /*########################################################################## -# Copyright (C) 2018 European Synchrotron Radiation Facility +# Copyright (C) 2018-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -48,7 +48,7 @@ cimport libc.string cimport cython -include "../../utils/_have_openmp.pxi" +from ...utils._have_openmp cimport COMPILED_WITH_OPENMP """Store in the module if it was compiled with OpenMP""" cdef double EPSILON = numpy.finfo(numpy.float64).eps diff --git a/silx/image/tomography.py b/silx/image/tomography.py index c2aedd8..53855c1 100644 --- a/silx/image/tomography.py +++ b/silx/image/tomography.py @@ -32,6 +32,7 @@ __date__ = "12/09/2017" import numpy as np from math import pi +from functools import lru_cache from itertools import product from bisect import bisect from silx.math.fit import leastsq @@ -128,6 +129,7 @@ def compute_fourier_filter(dwidth_padded, filter_name, cutoff=1.): return filt_f +@lru_cache(maxsize=1) def generate_powers(): """ Generate a list of powers of [2, 3, 5, 7], diff --git a/silx/io/commonh5.py b/silx/io/commonh5.py index b624816..57232d8 100644 --- a/silx/io/commonh5.py +++ b/silx/io/commonh5.py @@ -1,6 +1,6 @@ # coding: utf-8 # /*########################################################################## -# Copyright (C) 2016-2019 European Synchrotron Radiation Facility +# Copyright (C) 2016-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -376,6 +376,24 @@ class Dataset(Node): There is no chunks.""" return None + @property + def is_virtual(self): + """Checks virtual data as provided by `h5py.Dataset`""" + return False + + def virtual_sources(self): + """Returns virtual dataset sources as provided by `h5py.Dataset`. + + :rtype: list""" + raise RuntimeError("Not a virtual dataset") + + @property + def external(self): + """Returns external sources as provided by `h5py.Dataset`. + + :rtype: list or None""" + return None + def __array__(self, dtype=None): # Special case for (0,)*-shape datasets if numpy.product(self.shape) == 0: @@ -958,7 +976,7 @@ class Group(Node): raise TypeError("Path are not supported") if data is None: if dtype is None: - dtype = numpy.float + dtype = numpy.float64 data = numpy.empty(shape=shape, dtype=dtype) elif dtype is not None: data = data.astype(dtype) diff --git a/silx/io/dictdump.py b/silx/io/dictdump.py index f2318e0..bbb244a 100644 --- a/silx/io/dictdump.py +++ b/silx/io/dictdump.py @@ -34,9 +34,11 @@ import sys import h5py from .configdict import ConfigDict -from .utils import is_group +from .utils import is_group, is_link, is_softlink, is_externallink from .utils import is_file as is_h5_file_like from .utils import open as h5open +from .utils import h5py_read_dataset +from .utils import H5pyAttributesReadWrapper __authors__ = ["P. Knobel"] __license__ = "MIT" @@ -44,35 +46,24 @@ __date__ = "17/07/2018" logger = logging.getLogger(__name__) -string_types = (basestring,) if sys.version_info[0] == 2 else (str,) # noqa +vlen_utf8 = h5py.special_dtype(vlen=str) +vlen_bytes = h5py.special_dtype(vlen=bytes) -def _prepare_hdf5_dataset(array_like): +def _prepare_hdf5_write_value(array_like): """Cast a python object into a numpy array in a HDF5 friendly format. :param array_like: Input dataset in a type that can be digested by ``numpy.array()`` (`str`, `list`, `numpy.ndarray`…) :return: ``numpy.ndarray`` ready to be written as an HDF5 dataset """ - # simple strings - if isinstance(array_like, string_types): - array_like = numpy.string_(array_like) - - # Ensure our data is a numpy.ndarray - if not isinstance(array_like, (numpy.ndarray, numpy.string_)): - array = numpy.array(array_like) + array = numpy.asarray(array_like) + if numpy.issubdtype(array.dtype, numpy.bytes_): + return numpy.array(array_like, dtype=vlen_bytes) + elif numpy.issubdtype(array.dtype, numpy.str_): + return numpy.array(array_like, dtype=vlen_utf8) else: - array = array_like - - # handle list of strings or numpy array of strings - if not isinstance(array, numpy.string_): - data_kind = array.dtype.kind - # unicode: convert to byte strings - # (http://docs.h5py.org/en/latest/strings.html) - if data_kind.lower() in ["s", "u"]: - array = numpy.asarray(array, dtype=numpy.string_) - - return array + return array class _SafeH5FileWrite(object): @@ -219,150 +210,145 @@ def dicttoh5(treedict, h5file, h5path='/', h5f.create_group(h5path) for key in filter(lambda k: not isinstance(k, tuple), treedict): - if isinstance(treedict[key], dict) and len(treedict[key]): + key_is_group = isinstance(treedict[key], dict) + h5name = h5path + key + + if key_is_group and treedict[key]: # non-empty group: recurse - dicttoh5(treedict[key], h5f, h5path + key, + dicttoh5(treedict[key], h5f, h5name, overwrite_data=overwrite_data, create_dataset_args=create_dataset_args) + continue - elif treedict[key] is None or (isinstance(treedict[key], dict) and - not len(treedict[key])): - if (h5path + key) in h5f: - if overwrite_data is True: - del h5f[h5path + key] - else: - logger.warning('key (%s) already exists. ' - 'Not overwriting.' % (h5path + key)) - continue - # Create empty group - h5f.create_group(h5path + key) + if h5name in h5f: + # key already exists: delete or skip + if overwrite_data is True: + del h5f[h5name] + else: + logger.warning('key (%s) already exists. ' + 'Not overwriting.' % (h5name)) + continue + + value = treedict[key] + if value is None or key_is_group: + # Create empty group + h5f.create_group(h5name) + elif is_link(value): + h5f[h5name] = value else: - ds = _prepare_hdf5_dataset(treedict[key]) + data = _prepare_hdf5_write_value(value) # can't apply filters on scalars (datasets with shape == () ) - if ds.shape == () or create_dataset_args is None: - if h5path + key in h5f: - if overwrite_data is True: - del h5f[h5path + key] - else: - logger.warning('key (%s) already exists. ' - 'Not overwriting.' % (h5path + key)) - continue - - h5f.create_dataset(h5path + key, - data=ds) + if data.shape == () or create_dataset_args is None: + h5f.create_dataset(h5name, + data=data) else: - if h5path + key in h5f: - if overwrite_data is True: - del h5f[h5path + key] - else: - logger.warning('key (%s) already exists. ' - 'Not overwriting.' % (h5path + key)) - continue - - h5f.create_dataset(h5path + key, - data=ds, + h5f.create_dataset(h5name, + data=data, **create_dataset_args) # deal with h5 attributes which have tuples as keys in treedict for key in filter(lambda k: isinstance(k, tuple), treedict): - if (h5path + key[0]) not in h5f: + assert len(key) == 2, "attribute must be defined by 2 values" + h5name = h5path + key[0] + attr_name = key[1] + + if h5name not in h5f: # Create empty group if key for attr does not exist - h5f.create_group(h5path + key[0]) + h5f.create_group(h5name) logger.warning( "key (%s) does not exist. attr %s " - "will be written to ." % (h5path + key[0], key[1]) + "will be written to ." % (h5name, attr_name) ) - if key[1] in h5f[h5path + key[0]].attrs: + if attr_name in h5f[h5name].attrs: if not overwrite_data: logger.warning( "attribute %s@%s already exists. Not overwriting." - "" % (h5path + key[0], key[1]) + "" % (h5name, attr_name) ) continue # Write attribute value = treedict[key] + data = _prepare_hdf5_write_value(value) + h5f[h5name].attrs[attr_name] = data - # Makes list/tuple of str being encoded as vlen unicode array - # Workaround for h5py<2.9.0 (e.g. debian 10). - if (isinstance(value, (list, tuple)) and - numpy.asarray(value).dtype.type == numpy.unicode_): - value = numpy.array(value, dtype=h5py.special_dtype(vlen=str)) - - h5f[h5path + key[0]].attrs[key[1]] = value - -def dicttonx( - treedict, - h5file, - h5path="/", - mode="w", - overwrite_data=False, - create_dataset_args=None, -): - """ - Write a nested dictionary to a HDF5 file, using string keys as member names. - The NeXus convention is used to identify attributes with ``"@"`` character, - therefor the dataset_names should not contain ``"@"``. +def nexus_to_h5_dict(treedict, parents=tuple()): + """The following conversions are applied: + * key with "{name}@{attr_name}" notation: key converted to 2-tuple + * key with ">{url}" notation: strip ">" and convert value to + h5py.SoftLink or h5py.ExternalLink :param treedict: Nested dictionary/tree structure with strings as keys and array-like objects as leafs. The ``"/"`` character can be used to define sub tree. The ``"@"`` character is used to write attributes. + The ``">"`` prefix is used to define links. + :param parents: Needed to resolve up-links (tuple of HDF5 group names) - Detais on all other params can be found in doc of dicttoh5. + :rtype dict: + """ + copy = dict() + for key, value in treedict.items(): + if "@" in key: + key = tuple(key.rsplit("@", 1)) + elif key.startswith(">"): + if isinstance(value, str): + key = key[1:] + first, sep, second = value.partition("::") + if sep: + value = h5py.ExternalLink(first, second) + else: + if ".." in first: + # Up-links not supported: make absolute + parts = [] + for p in list(parents) + first.split("/"): + if not p or p == ".": + continue + elif p == "..": + parts.pop(-1) + else: + parts.append(p) + first = "/" + "/".join(parts) + value = h5py.SoftLink(first) + elif is_link(value): + key = key[1:] + if isinstance(value, dict): + copy[key] = nexus_to_h5_dict(value, parents=parents+(key,)) + else: + copy[key] = value + return copy - Example:: - import numpy - from silx.io.dictdump import dicttonx +def h5_to_nexus_dict(treedict): + """The following conversions are applied: + * 2-tuple key: converted to string ("@" notation) + * h5py.Softlink value: converted to string (">" key prefix) + * h5py.ExternalLink value: converted to string (">" key prefix) - gauss = { - "entry":{ - "title":u"A plot of a gaussian", - "plot": { - "y": numpy.array([0.08, 0.19, 0.39, 0.66, 0.9, 1., - 0.9, 0.66, 0.39, 0.19, 0.08]), - "x": numpy.arange(0,1.1,.1), - "@signal": "y", - "@axes": "x", - "@NX_class":u"NXdata", - "title:u"Gauss Plot", - }, - "@NX_class":u"NXentry", - "default":"plot", - } - "@NX_class": u"NXroot", - "@default": "entry", - } + :param treedict: Nested dictionary/tree structure with strings as keys + and array-like objects as leafs. The ``"/"`` character can be used + to define sub tree. - dicttonx(gauss,"test.h5") + :rtype dict: """ - - def copy_keys_keep_values(original): - # create a new treedict with with modified keys but keep values - copy = dict() - for key, value in original.items(): - if "@" in key: - newkey = tuple(key.rsplit("@", 1)) - else: - newkey = key - if isinstance(value, dict): - copy[newkey] = copy_keys_keep_values(value) - else: - copy[newkey] = value - return copy - - nxtreedict = copy_keys_keep_values(treedict) - dicttoh5( - nxtreedict, - h5file, - h5path=h5path, - mode=mode, - overwrite_data=overwrite_data, - create_dataset_args=create_dataset_args, - ) + copy = dict() + for key, value in treedict.items(): + if isinstance(key, tuple): + assert len(key)==2, "attribute must be defined by 2 values" + key = "%s@%s" % (key[0], key[1]) + elif is_softlink(value): + key = ">" + key + value = value.path + elif is_externallink(value): + key = ">" + key + value = value.filename + "::" + value.path + if isinstance(value, dict): + copy[key] = h5_to_nexus_dict(value) + else: + copy[key] = value + return copy def _name_contains_string_in_list(name, strlist): @@ -374,7 +360,31 @@ def _name_contains_string_in_list(name, strlist): return False -def h5todict(h5file, path="/", exclude_names=None, asarray=True): +def _handle_error(mode: str, exception, msg: str, *args) -> None: + """Handle errors. + + :param str mode: 'raise', 'log', 'ignore' + :param type exception: Exception class to use in 'raise' mode + :param str msg: Error message template + :param List[str] args: Arguments for error message template + """ + if mode == 'ignore': + return # no-op + elif mode == 'log': + logger.error(msg, *args) + elif mode == 'raise': + raise exception(msg % args) + else: + raise ValueError("Unsupported error handling: %s" % mode) + + +def h5todict(h5file, + path="/", + exclude_names=None, + asarray=True, + dereference_links=True, + include_attributes=False, + errors='raise'): """Read a HDF5 file and return a nested dictionary with the complete file structure and all data. @@ -397,7 +407,7 @@ def h5todict(h5file, path="/", exclude_names=None, asarray=True): .. note:: This function requires `h5py `_ to be installed. - .. note:: If you write a dictionary to a HDF5 file with + .. note:: If you write a dictionary to a HDF5 file with :func:`dicttoh5` and then read it back with :func:`h5todict`, data types are not preserved. All values are cast to numpy arrays before being written to file, and they are read back as numpy arrays (or @@ -412,28 +422,159 @@ def h5todict(h5file, path="/", exclude_names=None, asarray=True): a string in this list will be ignored. Default is None (ignore nothing) :param bool asarray: True (default) to read scalar as arrays, False to read them as scalar + :param bool dereference_links: True (default) to dereference links, False + to preserve the link itself + :param bool include_attributes: False (default) + :param str errors: Handling of errors (HDF5 access issue, broken link,...): + - 'raise' (default): Raise an exception + - 'log': Log as errors + - 'ignore': Ignore errors :return: Nested dictionary """ with _SafeH5FileRead(h5file) as h5f: ddict = {} - for key in h5f[path]: + if path not in h5f: + _handle_error( + errors, KeyError, 'Path "%s" does not exist in file.', path) + return ddict + + try: + root = h5f[path] + except KeyError as e: + if not isinstance(h5f.get(path, getlink=True), h5py.HardLink): + _handle_error(errors, + KeyError, + 'Cannot retrieve path "%s" (broken link)', + path) + else: + _handle_error(errors, KeyError, ', '.join(e.args)) + return ddict + + # Read the attributes of the group + if include_attributes: + attrs = H5pyAttributesReadWrapper(root.attrs) + for aname, avalue in attrs.items(): + ddict[("", aname)] = avalue + # Read the children of the group + for key in root: if _name_contains_string_in_list(key, exclude_names): continue - if is_group(h5f[path + "/" + key]): + h5name = path + "/" + key + # Preserve HDF5 link when requested + if not dereference_links: + lnk = h5f.get(h5name, getlink=True) + if is_link(lnk): + ddict[key] = lnk + continue + + try: + h5obj = h5f[h5name] + except KeyError as e: + if not isinstance(h5f.get(h5name, getlink=True), h5py.HardLink): + _handle_error(errors, + KeyError, + 'Cannot retrieve path "%s" (broken link)', + h5name) + else: + _handle_error(errors, KeyError, ', '.join(e.args)) + continue + + if is_group(h5obj): + # Child is an HDF5 group ddict[key] = h5todict(h5f, - path + "/" + key, + h5name, exclude_names=exclude_names, - asarray=asarray) + asarray=asarray, + dereference_links=dereference_links, + include_attributes=include_attributes) else: - # Read HDF5 datset - data = h5f[path + "/" + key][()] - if asarray: # Convert HDF5 dataset to numpy array - data = numpy.array(data, copy=False) - ddict[key] = data - + # Child is an HDF5 dataset + try: + data = h5py_read_dataset(h5obj) + except OSError: + _handle_error(errors, + OSError, + 'Cannot retrieve dataset "%s"', + h5name) + else: + if asarray: # Convert HDF5 dataset to numpy array + data = numpy.array(data, copy=False) + ddict[key] = data + # Read the attributes of the child + if include_attributes: + attrs = H5pyAttributesReadWrapper(h5obj.attrs) + for aname, avalue in attrs.items(): + ddict[(key, aname)] = avalue return ddict +def dicttonx(treedict, h5file, h5path="/", **kw): + """ + Write a nested dictionary to a HDF5 file, using string keys as member names. + The NeXus convention is used to identify attributes with ``"@"`` character, + therefore the dataset_names should not contain ``"@"``. + + Similarly, links are identified by keys starting with the ``">"`` character. + The corresponding value can be a soft or external link. + + :param treedict: Nested dictionary/tree structure with strings as keys + and array-like objects as leafs. The ``"/"`` character can be used + to define sub tree. The ``"@"`` character is used to write attributes. + The ``">"`` prefix is used to define links. + + The named parameters are passed to dicttoh5. + + Example:: + + import numpy + from silx.io.dictdump import dicttonx + + gauss = { + "entry":{ + "title":u"A plot of a gaussian", + "instrument": { + "@NX_class": u"NXinstrument", + "positioners": { + "@NX_class": u"NXCollection", + "x": numpy.arange(0,1.1,.1) + } + } + "plot": { + "y": numpy.array([0.08, 0.19, 0.39, 0.66, 0.9, 1., + 0.9, 0.66, 0.39, 0.19, 0.08]), + ">x": "../instrument/positioners/x", + "@signal": "y", + "@axes": "x", + "@NX_class":u"NXdata", + "title:u"Gauss Plot", + }, + "@NX_class": u"NXentry", + "default":"plot", + } + "@NX_class": u"NXroot", + "@default": "entry", + } + + dicttonx(gauss,"test.h5") + """ + parents = tuple(p for p in h5path.split("/") if p) + nxtreedict = nexus_to_h5_dict(treedict, parents=parents) + dicttoh5(nxtreedict, h5file, h5path=h5path, **kw) + + +def nxtodict(h5file, **kw): + """Read a HDF5 file and return a nested dictionary with the complete file + structure and all data. + + As opposed to h5todict, all keys will be strings and no h5py objects are + present in the tree. + + The named parameters are passed to h5todict. + """ + nxtreedict = h5todict(h5file, **kw) + return h5_to_nexus_dict(nxtreedict) + + def dicttojson(ddict, jsonfile, indent=None, mode="w"): """Serialize ``ddict`` as a JSON formatted stream to ``jsonfile``. diff --git a/silx/io/fabioh5.py b/silx/io/fabioh5.py index cfaa0a0..2fd719d 100755 --- a/silx/io/fabioh5.py +++ b/silx/io/fabioh5.py @@ -1,6 +1,6 @@ # coding: utf-8 # /*########################################################################## -# Copyright (C) 2016-2019 European Synchrotron Radiation Facility +# Copyright (C) 2016-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -656,13 +656,13 @@ class FabioReader(object): elif result_type.kind == "U": none_value = u"" elif result_type.kind == "f": - none_value = numpy.float("NaN") + none_value = numpy.float64("NaN") elif result_type.kind == "i": - none_value = numpy.int(0) + none_value = numpy.int64(0) elif result_type.kind == "u": - none_value = numpy.int(0) + none_value = numpy.int64(0) elif result_type.kind == "b": - none_value = numpy.bool(False) + none_value = numpy.bool_(False) else: none_value = None diff --git a/silx/io/nxdata/parse.py b/silx/io/nxdata/parse.py index 6bd18d6..b1c1bba 100644 --- a/silx/io/nxdata/parse.py +++ b/silx/io/nxdata/parse.py @@ -45,7 +45,7 @@ import json import numpy import six -from silx.io.utils import is_group, is_file, is_dataset +from silx.io.utils import is_group, is_file, is_dataset, h5py_read_dataset from ._utils import get_attr_as_unicode, INTERPDIM, nxdata_logger, \ get_uncertainties_names, get_signal_name, \ @@ -628,7 +628,7 @@ class NXdata(object): data_dataset_names = [self.signal_name] + self.axes_dataset_names if (title is not None and is_dataset(title) and "title" not in data_dataset_names): - return str(title[()]) + return str(h5py_read_dataset(title)) title = self.group.attrs.get("title") if title is None: diff --git a/silx/io/setup.py b/silx/io/setup.py index 4aaf324..9cafa17 100644 --- a/silx/io/setup.py +++ b/silx/io/setup.py @@ -51,7 +51,7 @@ else: SPECFILE_USE_GNU_SOURCE = int(SPECFILE_USE_GNU_SOURCE) if sys.platform == "win32": - define_macros = [('WIN32', None)] + define_macros = [('WIN32', None), ('SPECFILE_POSIX', None)] elif os.name.lower().startswith('posix'): define_macros = [('SPECFILE_POSIX', None)] # the best choice is to have _GNU_SOURCE defined diff --git a/silx/io/specfile/src/locale_management.c b/silx/io/specfile/src/locale_management.c index 54695f5..0c5f7ca 100644 --- a/silx/io/specfile/src/locale_management.c +++ b/silx/io/specfile/src/locale_management.c @@ -39,6 +39,9 @@ # else # ifdef SPECFILE_POSIX # include +# ifndef LOCALE_NAME_MAX_LENGTH +# define LOCALE_NAME_MAX_LENGTH 85 +# endif # endif # endif #endif @@ -60,7 +63,7 @@ double PyMcaAtof(const char * inputString) #else #ifdef SPECFILE_POSIX char *currentLocaleBuffer; - char localeBuffer[21]; + char localeBuffer[LOCALE_NAME_MAX_LENGTH + 1] = {'\0'}; double result; currentLocaleBuffer = setlocale(LC_NUMERIC, NULL); strcpy(localeBuffer, currentLocaleBuffer); diff --git a/silx/io/test/test_dictdump.py b/silx/io/test/test_dictdump.py index c0b6914..b99116b 100644 --- a/silx/io/test/test_dictdump.py +++ b/silx/io/test/test_dictdump.py @@ -43,6 +43,8 @@ from .. import dictdump from ..dictdump import dicttoh5, dicttojson, dump from ..dictdump import h5todict, load from ..dictdump import logger as dictdump_logger +from ..utils import is_link +from ..utils import h5py_read_dataset def tree(): @@ -58,15 +60,29 @@ city_attrs["Europe"]["France"]["Grenoble"]["inhabitants"] = inhabitants city_attrs["Europe"]["France"]["Grenoble"]["coordinates"] = [45.1830, 5.7196] city_attrs["Europe"]["France"]["Tourcoing"]["area"] +ext_attrs = tree() +ext_attrs["ext_group"]["dataset"] = 10 +ext_filename = "ext.h5" + +link_attrs = tree() +link_attrs["links"]["group"]["dataset"] = 10 +link_attrs["links"]["group"]["relative_softlink"] = h5py.SoftLink("dataset") +link_attrs["links"]["relative_softlink"] = h5py.SoftLink("group/dataset") +link_attrs["links"]["absolute_softlink"] = h5py.SoftLink("/links/group/dataset") +link_attrs["links"]["external_link"] = h5py.ExternalLink(ext_filename, "/ext_group/dataset") + class TestDictToH5(unittest.TestCase): def setUp(self): self.tempdir = tempfile.mkdtemp() self.h5_fname = os.path.join(self.tempdir, "cityattrs.h5") + self.h5_ext_fname = os.path.join(self.tempdir, ext_filename) def tearDown(self): if os.path.exists(self.h5_fname): os.unlink(self.h5_fname) + if os.path.exists(self.h5_ext_fname): + os.unlink(self.h5_ext_fname) os.rmdir(self.tempdir) def testH5CityAttrs(self): @@ -201,31 +217,129 @@ class TestDictToH5(unittest.TestCase): self.assertEqual(h5file["group/group/dataset"].attrs['attr'], 11) self.assertEqual(h5file["group/group"].attrs['attr'], 12) + def testLinks(self): + with h5py.File(self.h5_ext_fname, "w") as h5file: + dictdump.dicttoh5(ext_attrs, h5file) + with h5py.File(self.h5_fname, "w") as h5file: + dictdump.dicttoh5(link_attrs, h5file) + with h5py.File(self.h5_fname, "r") as h5file: + self.assertEqual(h5file["links/group/dataset"][()], 10) + self.assertEqual(h5file["links/group/relative_softlink"][()], 10) + self.assertEqual(h5file["links/relative_softlink"][()], 10) + self.assertEqual(h5file["links/absolute_softlink"][()], 10) + self.assertEqual(h5file["links/external_link"][()], 10) + + def testDumpNumpyArray(self): + ddict = { + 'darks': { + '0': numpy.array([[0, 0, 0], [0, 0, 0]], dtype=numpy.uint16) + } + } + with h5py.File(self.h5_fname, "w") as h5file: + dictdump.dicttoh5(ddict, h5file) + with h5py.File(self.h5_fname, "r") as h5file: + numpy.testing.assert_array_equal(h5py_read_dataset(h5file["darks"]["0"]), + ddict['darks']['0']) + + +class TestH5ToDict(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + self.h5_fname = os.path.join(self.tempdir, "cityattrs.h5") + self.h5_ext_fname = os.path.join(self.tempdir, ext_filename) + dicttoh5(city_attrs, self.h5_fname) + dicttoh5(link_attrs, self.h5_fname, mode="a") + dicttoh5(ext_attrs, self.h5_ext_fname) + + def tearDown(self): + if os.path.exists(self.h5_fname): + os.unlink(self.h5_fname) + if os.path.exists(self.h5_ext_fname): + os.unlink(self.h5_ext_fname) + os.rmdir(self.tempdir) + + def testExcludeNames(self): + ddict = h5todict(self.h5_fname, path="/Europe/France", + exclude_names=["ourcoing", "inhab", "toto"]) + self.assertNotIn("Tourcoing", ddict) + self.assertIn("Grenoble", ddict) + + self.assertNotIn("inhabitants", ddict["Grenoble"]) + self.assertIn("coordinates", ddict["Grenoble"]) + self.assertIn("area", ddict["Grenoble"]) + + def testAsArrayTrue(self): + """Test with asarray=True, the default""" + ddict = h5todict(self.h5_fname, path="/Europe/France/Grenoble") + self.assertTrue(numpy.array_equal(ddict["inhabitants"], numpy.array(inhabitants))) + + def testAsArrayFalse(self): + """Test with asarray=False""" + ddict = h5todict(self.h5_fname, path="/Europe/France/Grenoble", asarray=False) + self.assertEqual(ddict["inhabitants"], inhabitants) + + def testDereferenceLinks(self): + ddict = h5todict(self.h5_fname, path="links", dereference_links=True) + self.assertTrue(ddict["absolute_softlink"], 10) + self.assertTrue(ddict["relative_softlink"], 10) + self.assertTrue(ddict["external_link"], 10) + self.assertTrue(ddict["group"]["relative_softlink"], 10) + + def testPreserveLinks(self): + ddict = h5todict(self.h5_fname, path="links", dereference_links=False) + self.assertTrue(is_link(ddict["absolute_softlink"])) + self.assertTrue(is_link(ddict["relative_softlink"])) + self.assertTrue(is_link(ddict["external_link"])) + self.assertTrue(is_link(ddict["group"]["relative_softlink"])) + + def testStrings(self): + ddict = {"dset_bytes": b"bytes", + "dset_utf8": "utf8", + "dset_2bytes": [b"bytes", b"bytes"], + "dset_2utf8": ["utf8", "utf8"], + ("", "attr_bytes"): b"bytes", + ("", "attr_utf8"): "utf8", + ("", "attr_2bytes"): [b"bytes", b"bytes"], + ("", "attr_2utf8"): ["utf8", "utf8"]} + dicttoh5(ddict, self.h5_fname, mode="w") + adict = h5todict(self.h5_fname, include_attributes=True, asarray=False) + self.assertEqual(ddict["dset_bytes"], adict["dset_bytes"]) + self.assertEqual(ddict["dset_utf8"], adict["dset_utf8"]) + self.assertEqual(ddict[("", "attr_bytes")], adict[("", "attr_bytes")]) + self.assertEqual(ddict[("", "attr_utf8")], adict[("", "attr_utf8")]) + numpy.testing.assert_array_equal(ddict["dset_2bytes"], adict["dset_2bytes"]) + numpy.testing.assert_array_equal(ddict["dset_2utf8"], adict["dset_2utf8"]) + numpy.testing.assert_array_equal(ddict[("", "attr_2bytes")], adict[("", "attr_2bytes")]) + numpy.testing.assert_array_equal(ddict[("", "attr_2utf8")], adict[("", "attr_2utf8")]) + class TestDictToNx(unittest.TestCase): def setUp(self): self.tempdir = tempfile.mkdtemp() self.h5_fname = os.path.join(self.tempdir, "nx.h5") + self.h5_ext_fname = os.path.join(self.tempdir, "nx_ext.h5") def tearDown(self): if os.path.exists(self.h5_fname): os.unlink(self.h5_fname) + if os.path.exists(self.h5_ext_fname): + os.unlink(self.h5_ext_fname) os.rmdir(self.tempdir) def testAttributes(self): """Any kind of attribute can be described""" ddict = { - "group": {"datatset": "hmmm", "@group_attr": 10}, - "dataset": "aaaaaaaaaaaaaaa", + "group": {"dataset": 100, "@group_attr1": 10}, + "dataset": 200, "@root_attr": 11, - "dataset@dataset_attr": 12, + "dataset@dataset_attr": "12", "group@group_attr2": 13, } with h5py.File(self.h5_fname, "w") as h5file: dictdump.dicttonx(ddict, h5file) - self.assertEqual(h5file["group"].attrs['group_attr'], 10) + self.assertEqual(h5file["group"].attrs['group_attr1'], 10) self.assertEqual(h5file.attrs['root_attr'], 11) - self.assertEqual(h5file["dataset"].attrs['dataset_attr'], 12) + self.assertEqual(h5file["dataset"].attrs['dataset_attr'], "12") self.assertEqual(h5file["group"].attrs['group_attr2'], 13) def testKeyOrder(self): @@ -280,36 +394,120 @@ class TestDictToNx(unittest.TestCase): self.assertEqual(h5file["group/group/dataset"].attrs['attr'], 11) self.assertEqual(h5file["group/group"].attrs['attr'], 12) - -class TestH5ToDict(unittest.TestCase): + def testLinks(self): + ddict = {"ext_group": {"dataset": 10}} + dictdump.dicttonx(ddict, self.h5_ext_fname) + ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"}, + ">relative_softlink": "group/dataset", + ">absolute_softlink": "/links/group/dataset", + ">external_link": "nx_ext.h5::/ext_group/dataset"}} + dictdump.dicttonx(ddict, self.h5_fname) + with h5py.File(self.h5_fname, "r") as h5file: + self.assertEqual(h5file["links/group/dataset"][()], 10) + self.assertEqual(h5file["links/group/relative_softlink"][()], 10) + self.assertEqual(h5file["links/relative_softlink"][()], 10) + self.assertEqual(h5file["links/absolute_softlink"][()], 10) + self.assertEqual(h5file["links/external_link"][()], 10) + + def testUpLinks(self): + ddict = {"data": {"group": {"dataset": 10, ">relative_softlink": "dataset"}}, + "links": {"group": {"subgroup": {">relative_softlink": "../../../data/group/dataset"}}}} + dictdump.dicttonx(ddict, self.h5_fname) + with h5py.File(self.h5_fname, "r") as h5file: + self.assertEqual(h5file["/links/group/subgroup/relative_softlink"][()], 10) + + +class TestNxToDict(unittest.TestCase): def setUp(self): self.tempdir = tempfile.mkdtemp() - self.h5_fname = os.path.join(self.tempdir, "cityattrs.h5") - dicttoh5(city_attrs, self.h5_fname) + self.h5_fname = os.path.join(self.tempdir, "nx.h5") + self.h5_ext_fname = os.path.join(self.tempdir, "nx_ext.h5") def tearDown(self): - os.unlink(self.h5_fname) + if os.path.exists(self.h5_fname): + os.unlink(self.h5_fname) + if os.path.exists(self.h5_ext_fname): + os.unlink(self.h5_ext_fname) os.rmdir(self.tempdir) - def testExcludeNames(self): - ddict = h5todict(self.h5_fname, path="/Europe/France", - exclude_names=["ourcoing", "inhab", "toto"]) - self.assertNotIn("Tourcoing", ddict) - self.assertIn("Grenoble", ddict) - - self.assertNotIn("inhabitants", ddict["Grenoble"]) - self.assertIn("coordinates", ddict["Grenoble"]) - self.assertIn("area", ddict["Grenoble"]) - - def testAsArrayTrue(self): - """Test with asarray=True, the default""" - ddict = h5todict(self.h5_fname, path="/Europe/France/Grenoble") - self.assertTrue(numpy.array_equal(ddict["inhabitants"], numpy.array(inhabitants))) - - def testAsArrayFalse(self): - """Test with asarray=False""" - ddict = h5todict(self.h5_fname, path="/Europe/France/Grenoble", asarray=False) - self.assertEqual(ddict["inhabitants"], inhabitants) + def testAttributes(self): + """Any kind of attribute can be described""" + ddict = { + "group": {"dataset": 100, "@group_attr1": 10}, + "dataset": 200, + "@root_attr": 11, + "dataset@dataset_attr": "12", + "group@group_attr2": 13, + } + dictdump.dicttonx(ddict, self.h5_fname) + ddict = dictdump.nxtodict(self.h5_fname, include_attributes=True) + self.assertEqual(ddict["group"]["@group_attr1"], 10) + self.assertEqual(ddict["@root_attr"], 11) + self.assertEqual(ddict["dataset@dataset_attr"], "12") + self.assertEqual(ddict["group"]["@group_attr2"], 13) + + def testDereferenceLinks(self): + """Write links and dereference on read""" + ddict = {"ext_group": {"dataset": 10}} + dictdump.dicttonx(ddict, self.h5_ext_fname) + ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"}, + ">relative_softlink": "group/dataset", + ">absolute_softlink": "/links/group/dataset", + ">external_link": "nx_ext.h5::/ext_group/dataset"}} + dictdump.dicttonx(ddict, self.h5_fname) + + ddict = dictdump.h5todict(self.h5_fname, dereference_links=True) + self.assertTrue(ddict["links"]["absolute_softlink"], 10) + self.assertTrue(ddict["links"]["relative_softlink"], 10) + self.assertTrue(ddict["links"]["external_link"], 10) + self.assertTrue(ddict["links"]["group"]["relative_softlink"], 10) + + def testPreserveLinks(self): + """Write/read links""" + ddict = {"ext_group": {"dataset": 10}} + dictdump.dicttonx(ddict, self.h5_ext_fname) + ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"}, + ">relative_softlink": "group/dataset", + ">absolute_softlink": "/links/group/dataset", + ">external_link": "nx_ext.h5::/ext_group/dataset"}} + dictdump.dicttonx(ddict, self.h5_fname) + + ddict = dictdump.nxtodict(self.h5_fname, dereference_links=False) + self.assertTrue(ddict["links"][">absolute_softlink"], "dataset") + self.assertTrue(ddict["links"][">relative_softlink"], "group/dataset") + self.assertTrue(ddict["links"][">external_link"], "/links/group/dataset") + self.assertTrue(ddict["links"]["group"][">relative_softlink"], "nx_ext.h5::/ext_group/datase") + + def testNotExistingPath(self): + """Test converting not existing path""" + with h5py.File(self.h5_fname, 'a') as f: + f['data'] = 1 + + ddict = h5todict(self.h5_fname, path="/I/am/not/a/path", errors='ignore') + self.assertFalse(ddict) + + with TestLogging(dictdump_logger, error=1): + ddict = h5todict(self.h5_fname, path="/I/am/not/a/path", errors='log') + self.assertFalse(ddict) + + with self.assertRaises(KeyError): + h5todict(self.h5_fname, path="/I/am/not/a/path", errors='raise') + + def testBrokenLinks(self): + """Test with broken links""" + with h5py.File(self.h5_fname, 'a') as f: + f["/Mars/BrokenSoftLink"] = h5py.SoftLink("/Idontexists") + f["/Mars/BrokenExternalLink"] = h5py.ExternalLink("notexistingfile.h5", "/Idontexists") + + ddict = h5todict(self.h5_fname, path="/Mars", errors='ignore') + self.assertFalse(ddict) + + with TestLogging(dictdump_logger, error=2): + ddict = h5todict(self.h5_fname, path="/Mars", errors='log') + self.assertFalse(ddict) + + with self.assertRaises(KeyError): + h5todict(self.h5_fname, path="/Mars", errors='raise') class TestDictToJson(unittest.TestCase): @@ -436,6 +634,7 @@ def suite(): test_suite.addTest(loadTests(TestDictToNx)) test_suite.addTest(loadTests(TestDictToJson)) test_suite.addTest(loadTests(TestH5ToDict)) + test_suite.addTest(loadTests(TestNxToDict)) return test_suite diff --git a/silx/io/test/test_spectoh5.py b/silx/io/test/test_spectoh5.py index c3f03e9..903a62c 100644 --- a/silx/io/test/test_spectoh5.py +++ b/silx/io/test/test_spectoh5.py @@ -33,6 +33,7 @@ import h5py from ..spech5 import SpecH5, SpecH5Group from ..convert import convert, write_to_h5 +from ..utils import h5py_read_dataset __authors__ = ["P. Knobel"] __license__ = "MIT" @@ -129,7 +130,7 @@ class TestConvertSpecHDF5(unittest.TestCase): def testTitle(self): """Test the value of a dataset""" - title12 = self.h5f["/1.2/title"][()] + title12 = h5py_read_dataset(self.h5f["/1.2/title"]) self.assertEqual(title12, u"aaaaaa") diff --git a/silx/io/test/test_url.py b/silx/io/test/test_url.py index e68c67a..114f6a7 100644 --- a/silx/io/test/test_url.py +++ b/silx/io/test/test_url.py @@ -152,6 +152,16 @@ class TestDataUrl(unittest.TestCase): expected = [True, True, None, "/a.h5", "/b", (5, 1)] self.assertUrl(url, expected) + def test_slice2(self): + url = DataUrl("/a.h5?path=/b&slice=2:5") + expected = [True, True, None, "/a.h5", "/b", (slice(2, 5),)] + self.assertUrl(url, expected) + + def test_slice3(self): + url = DataUrl("/a.h5?path=/b&slice=::2") + expected = [True, True, None, "/a.h5", "/b", (slice(None, None, 2),)] + self.assertUrl(url, expected) + def test_slice_ellipsis(self): url = DataUrl("/a.h5?path=/b&slice=...") expected = [True, True, None, "/a.h5", "/b", (Ellipsis, )] diff --git a/silx/io/test/test_utils.py b/silx/io/test/test_utils.py index 6c70636..13ab532 100644 --- a/silx/io/test/test_utils.py +++ b/silx/io/test/test_utils.py @@ -33,6 +33,7 @@ import unittest import sys from .. import utils +from ..._version import calc_hexversion import silx.io.url import h5py @@ -40,11 +41,9 @@ from ..utils import h5ls import fabio - __authors__ = ["P. Knobel"] __license__ = "MIT" -__date__ = "12/02/2018" - +__date__ = "03/12/2020" expected_spec1 = r"""#F .* #D .* @@ -67,6 +66,28 @@ expected_spec2 = expected_spec1 + r""" 2 8\.00 3 9\.00 """ + +expected_spec2reg = r"""#F .* +#D .* + +#S 1 Ordinate1 +#D .* +#N 3 +#L Abscissa Ordinate1 Ordinate2 +1 4\.00 7\.00 +2 5\.00 8\.00 +3 6\.00 9\.00 +""" + +expected_spec2irr = expected_spec1 + r""" +#S 2 Ordinate2 +#D .* +#N 2 +#L Abscissa Ordinate2 +1 7\.00 +2 8\.00 +""" + expected_csv = r"""Abscissa;Ordinate1;Ordinate2 1;4\.00;7\.00e\+00 2;5\.00;8\.00e\+00 @@ -83,6 +104,7 @@ expected_csv2 = r"""x;y0;y1 class TestSave(unittest.TestCase): """Test saving curves as SpecFile: """ + def setUp(self): self.tempdir = tempfile.mkdtemp() self.spec_fname = os.path.join(self.tempdir, "savespec.dat") @@ -92,6 +114,7 @@ class TestSave(unittest.TestCase): self.x = [1, 2, 3] self.xlab = "Abscissa" self.y = [[4, 5, 6], [7, 8, 9]] + self.y_irr = [[4, 5, 6], [7, 8]] self.ylabs = ["Ordinate1", "Ordinate2"] def tearDown(self): @@ -103,13 +126,6 @@ class TestSave(unittest.TestCase): os.unlink(self.npy_fname) shutil.rmtree(self.tempdir) - def assertRegex(self, *args, **kwargs): - # Python 2 compatibility - if sys.version_info.major >= 3: - return super(TestSave, self).assertRegex(*args, **kwargs) - else: - return self.assertRegexpMatches(*args, **kwargs) - def test_save_csv(self): utils.save1D(self.csv_fname, self.x, self.y, xlabel=self.xlab, ylabels=self.ylabs, @@ -145,7 +161,6 @@ class TestSave(unittest.TestCase): specf = open(self.spec_fname) actual_spec = specf.read() specf.close() - self.assertRegex(actual_spec, expected_spec1) def test_savespec_file_handle(self): @@ -165,18 +180,30 @@ class TestSave(unittest.TestCase): specf = open(self.spec_fname) actual_spec = specf.read() specf.close() - self.assertRegex(actual_spec, expected_spec2) - def test_save_spec(self): - """Save SpecFile using save()""" + def test_save_spec_reg(self): + """Save SpecFile using save() on a regular pattern""" utils.save1D(self.spec_fname, self.x, self.y, xlabel=self.xlab, ylabels=self.ylabs, filetype="spec", fmt=["%d", "%.2f"]) specf = open(self.spec_fname) actual_spec = specf.read() specf.close() - self.assertRegex(actual_spec, expected_spec2) + + self.assertRegex(actual_spec, expected_spec2reg) + + def test_save_spec_irr(self): + """Save SpecFile using save() on an irregular pattern""" + # invalid test case ?! + return + utils.save1D(self.spec_fname, self.x, self.y_irr, xlabel=self.xlab, + ylabels=self.ylabs, filetype="spec", fmt=["%d", "%.2f"]) + + specf = open(self.spec_fname) + actual_spec = specf.read() + specf.close() + self.assertRegex(actual_spec, expected_spec2irr) def test_save_csv_no_labels(self): """Save csv using save(), with autoheader=True but @@ -217,6 +244,7 @@ class TestH5Ls(unittest.TestCase): """ + def assertMatchAnyStringInList(self, pattern, list_of_strings): for string_ in list_of_strings: if re.match(pattern, string_): @@ -395,6 +423,7 @@ class TestOpen(unittest.TestCase): class TestNodes(unittest.TestCase): """Test `silx.io.utils.is_` functions.""" + def test_real_h5py_objects(self): name = tempfile.mktemp(suffix=".h5") try: @@ -417,45 +446,60 @@ class TestNodes(unittest.TestCase): os.unlink(name) def test_h5py_like_file(self): + class Foo(object): + def __init__(self): self.h5_class = utils.H5Type.FILE + obj = Foo() self.assertTrue(utils.is_file(obj)) self.assertTrue(utils.is_group(obj)) self.assertFalse(utils.is_dataset(obj)) def test_h5py_like_group(self): + class Foo(object): + def __init__(self): self.h5_class = utils.H5Type.GROUP + obj = Foo() self.assertFalse(utils.is_file(obj)) self.assertTrue(utils.is_group(obj)) self.assertFalse(utils.is_dataset(obj)) def test_h5py_like_dataset(self): + class Foo(object): + def __init__(self): self.h5_class = utils.H5Type.DATASET + obj = Foo() self.assertFalse(utils.is_file(obj)) self.assertFalse(utils.is_group(obj)) self.assertTrue(utils.is_dataset(obj)) def test_bad(self): + class Foo(object): + def __init__(self): pass + obj = Foo() self.assertFalse(utils.is_file(obj)) self.assertFalse(utils.is_group(obj)) self.assertFalse(utils.is_dataset(obj)) def test_bad_api(self): + class Foo(object): + def __init__(self): self.h5_class = int + obj = Foo() self.assertFalse(utils.is_file(obj)) self.assertFalse(utils.is_group(obj)) @@ -513,18 +557,20 @@ class TestGetData(unittest.TestCase): def test_hdf5_array(self): url = "silx:%s?/group/group/array" % self.h5_filename data = utils.get_data(url=url) - self.assertEqual(data.shape, (5, )) + self.assertEqual(data.shape, (5,)) self.assertEqual(data[0], 1) def test_hdf5_array_slice(self): url = "silx:%s?path=/group/group/array2d&slice=1" % self.h5_filename data = utils.get_data(url=url) - self.assertEqual(data.shape, (5, )) + self.assertEqual(data.shape, (5,)) self.assertEqual(data[0], 6) def test_hdf5_array_slice_out_of_range(self): url = "silx:%s?path=/group/group/array2d&slice=5" % self.h5_filename - self.assertRaises(ValueError, utils.get_data, url) + # ValueError: h5py 2.x + # IndexError: h5py 3.x + self.assertRaises((ValueError, IndexError), utils.get_data, url) def test_edf_using_silx(self): url = "silx:%s?/scan_0/instrument/detector_0/data" % self.edf_filename @@ -568,14 +614,15 @@ class TestGetData(unittest.TestCase): def _h5_py_version_older_than(version): - v_majeur, v_mineur, v_micro = h5py.version.version.split('.')[:3] - r_majeur, r_mineur, r_micro = version.split('.') - return v_majeur >= r_majeur and v_mineur >= r_mineur + v_majeur, v_mineur, v_micro = [int(i) for i in h5py.version.version.split('.')[:3]] + r_majeur, r_mineur, r_micro = [int(i) for i in version.split('.')] + return calc_hexversion(v_majeur, v_mineur, v_micro) >= calc_hexversion(r_majeur, r_mineur, r_micro) @unittest.skipUnless(_h5_py_version_older_than('2.9.0'), 'h5py version < 2.9.0') class TestRawFileToH5(unittest.TestCase): """Test conversion of .vol file to .h5 external dataset""" + def setUp(self): self.tempdir = tempfile.mkdtemp() self._vol_file = os.path.join(self.tempdir, 'test_vol.vol') @@ -589,7 +636,7 @@ class TestRawFileToH5(unittest.TestCase): assert os.path.exists(self._vol_file + '.npy') os.rename(self._vol_file + '.npy', self._vol_file) self.h5_file = os.path.join(self.tempdir, 'test_h5.h5') - self.external_dataset_path= '/root/my_external_dataset' + self.external_dataset_path = '/root/my_external_dataset' self._data_url = silx.io.url.DataUrl(file_path=self.h5_file, data_path=self.external_dataset_path) with open(self._file_info, 'w') as _fi: @@ -672,6 +719,158 @@ class TestRawFileToH5(unittest.TestCase): shape=self._dataset_shape)) +class TestH5Strings(unittest.TestCase): + """Test HDF5 str and bytes writing and reading""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.vlenstr = h5py.special_dtype(vlen=str) + cls.vlenbytes = h5py.special_dtype(vlen=bytes) + try: + cls.unicode = unicode + except NameError: + cls.unicode = str + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir) + + def setUp(self): + self.file = h5py.File(os.path.join(self.tempdir, 'file.h5'), mode="w") + + def tearDown(self): + self.file.close() + + @classmethod + def _make_array(cls, value, n): + if isinstance(value, bytes): + dtype = cls.vlenbytes + elif isinstance(value, cls.unicode): + dtype = cls.vlenstr + else: + return numpy.array([value] * n) + return numpy.array([value] * n, dtype=dtype) + + @classmethod + def _get_charset(cls, value): + if isinstance(value, bytes): + return h5py.h5t.CSET_ASCII + elif isinstance(value, cls.unicode): + return h5py.h5t.CSET_UTF8 + else: + return None + + def _check_dataset(self, value, result=None): + # Write+read scalar + if result: + decode_ascii = True + else: + decode_ascii = False + result = value + charset = self._get_charset(value) + self.file["data"] = value + data = utils.h5py_read_dataset(self.file["data"], decode_ascii=decode_ascii) + assert type(data) == type(result), data + assert data == result, data + if charset: + assert self.file["data"].id.get_type().get_cset() == charset + + # Write+read variable length + self.file["vlen_data"] = self._make_array(value, 2) + data = utils.h5py_read_dataset(self.file["vlen_data"], decode_ascii=decode_ascii, index=0) + assert type(data) == type(result), data + assert data == result, data + data = utils.h5py_read_dataset(self.file["vlen_data"], decode_ascii=decode_ascii) + numpy.testing.assert_array_equal(data, [result] * 2) + if charset: + assert self.file["vlen_data"].id.get_type().get_cset() == charset + + def _check_attribute(self, value, result=None): + if result: + decode_ascii = True + else: + decode_ascii = False + result = value + self.file.attrs["data"] = value + data = utils.h5py_read_attribute(self.file.attrs, "data", decode_ascii=decode_ascii) + assert type(data) == type(result), data + assert data == result, data + + self.file.attrs["vlen_data"] = self._make_array(value, 2) + data = utils.h5py_read_attribute(self.file.attrs, "vlen_data", decode_ascii=decode_ascii) + assert type(data[0]) == type(result), data[0] + assert data[0] == result, data[0] + numpy.testing.assert_array_equal(data, [result] * 2) + + data = utils.h5py_read_attributes(self.file.attrs, decode_ascii=decode_ascii)["vlen_data"] + assert type(data[0]) == type(result), data[0] + assert data[0] == result, data[0] + numpy.testing.assert_array_equal(data, [result] * 2) + + def test_dataset_ascii_bytes(self): + self._check_dataset(b"abc") + + def test_attribute_ascii_bytes(self): + self._check_attribute(b"abc") + + def test_dataset_ascii_bytes_decode(self): + self._check_dataset(b"abc", result="abc") + + def test_attribute_ascii_bytes_decode(self): + self._check_attribute(b"abc", result="abc") + + def test_dataset_ascii_str(self): + self._check_dataset("abc") + + def test_attribute_ascii_str(self): + self._check_attribute("abc") + + def test_dataset_utf8_str(self): + self._check_dataset("\u0101bc") + + def test_attribute_utf8_str(self): + self._check_attribute("\u0101bc") + + def test_dataset_utf8_bytes(self): + # 0xC481 is the byte representation of U+0101 + self._check_dataset(b"\xc4\x81bc") + + def test_attribute_utf8_bytes(self): + # 0xC481 is the byte representation of U+0101 + self._check_attribute(b"\xc4\x81bc") + + def test_dataset_utf8_bytes_decode(self): + # 0xC481 is the byte representation of U+0101 + self._check_dataset(b"\xc4\x81bc", result="\u0101bc") + + def test_attribute_utf8_bytes_decode(self): + # 0xC481 is the byte representation of U+0101 + self._check_attribute(b"\xc4\x81bc", result="\u0101bc") + + def test_dataset_latin1_bytes(self): + # extended ascii character 0xE4 + self._check_dataset(b"\xe423") + + def test_attribute_latin1_bytes(self): + # extended ascii character 0xE4 + self._check_attribute(b"\xe423") + + def test_dataset_latin1_bytes_decode(self): + # U+DCE4: surrogate for extended ascii character 0xE4 + self._check_dataset(b"\xe423", result="\udce423") + + def test_attribute_latin1_bytes_decode(self): + # U+DCE4: surrogate for extended ascii character 0xE4 + self._check_attribute(b"\xe423", result="\udce423") + + def test_dataset_no_string(self): + self._check_dataset(numpy.int64(10)) + + def test_attribute_no_string(self): + self._check_attribute(numpy.int64(10)) + + def suite(): loadTests = unittest.defaultTestLoader.loadTestsFromTestCase test_suite = unittest.TestSuite() @@ -681,6 +880,7 @@ def suite(): test_suite.addTest(loadTests(TestNodes)) test_suite.addTest(loadTests(TestGetData)) test_suite.addTest(loadTests(TestRawFileToH5)) + test_suite.addTest(loadTests(TestH5Strings)) return test_suite diff --git a/silx/io/url.py b/silx/io/url.py index 7607ae5..044977c 100644 --- a/silx/io/url.py +++ b/silx/io/url.py @@ -178,8 +178,20 @@ class DataUrl(object): def str_to_slice(string): if string == "...": return Ellipsis - elif string == ":": - return slice(None) + elif ':' in string: + if string == ":": + return slice(None) + else: + def get_value(my_str): + if my_str in ('', None): + return None + else: + return int(my_str) + sss = string.split(':') + start = get_value(sss[0]) + stop = get_value(sss[1] if len(sss) > 1 else None) + step = get_value(sss[2] if len(sss) > 2 else None) + return slice(start, stop, step) else: return int(string) @@ -201,7 +213,10 @@ class DataUrl(object): :param str path: Path representing the URL. """ self.__path = path - path = path.replace("::", "?", 1) + # only replace if ? not here already. Otherwise can mess sith + # data_slice if == ::2 for example + if '?' not in path: + path = path.replace("::", "?", 1) url = parse.urlparse(path) is_valid = True diff --git a/silx/io/utils.py b/silx/io/utils.py index 5da344d..12e9a7e 100644 --- a/silx/io/utils.py +++ b/silx/io/utils.py @@ -25,8 +25,7 @@ __authors__ = ["P. Knobel", "V. Valls"] __license__ = "MIT" -__date__ = "18/04/2018" - +__date__ = "03/12/2020" import enum import os.path @@ -40,18 +39,19 @@ import six from silx.utils.proxy import Proxy import silx.io.url +from .._version import calc_hexversion import h5py +import h5py.h5t +import h5py.h5a try: import h5pyd except ImportError as e: h5pyd = None - logger = logging.getLogger(__name__) - NEXUS_HDF5_EXT = [".h5", ".nx5", ".nxs", ".hdf", ".hdf5", ".cxi"] """List of possible extensions for HDF5 file formats.""" @@ -190,34 +190,46 @@ def save1D(fname, x, y, xlabel=None, ylabels=None, filetype=None, if xlabel is None: xlabel = "x" if ylabels is None: - if len(numpy.array(y).shape) > 1: + if numpy.array(y).ndim > 1: ylabels = ["y%d" % i for i in range(len(y))] else: ylabels = ["y"] elif isinstance(ylabels, (list, tuple)): # if ylabels is provided as a list, every element must # be a string - ylabels = [ylabels[i] if ylabels[i] is not None else "y%d" % i - for i in range(len(ylabels))] + ylabels = [ylabel if isinstance(ylabel, string_types) else "y%d" % i + for ylabel in ylabels] if filetype.lower() == "spec": - y_array = numpy.asarray(y) - - # make sure y_array is a 2D array even for a single curve - if len(y_array.shape) == 1: - y_array = y_array.reshape(1, y_array.shape[0]) - elif len(y_array.shape) > 2 or len(y_array.shape) < 1: - raise IndexError("y must be a 1D or 2D array") - - # First curve - specf = savespec(fname, x, y_array[0], xlabel, ylabels[0], fmt=fmt, - scan_number=1, mode="w", write_file_header=True, - close_file=False) - # Other curves - for i in range(1, y_array.shape[0]): - specf = savespec(specf, x, y_array[i], xlabel, ylabels[i], - fmt=fmt, scan_number=i + 1, mode="w", - write_file_header=False, close_file=False) + # Check if we have regular data: + ref = len(x) + regular = True + for one_y in y: + regular &= len(one_y) == ref + if regular: + if isinstance(fmt, (list, tuple)) and len(fmt) < (len(ylabels) + 1): + fmt = fmt + [fmt[-1] * (1 + len(ylabels) - len(fmt))] + specf = savespec(fname, x, y, xlabel, ylabels, fmt=fmt, + scan_number=1, mode="w", write_file_header=True, + close_file=False) + else: + y_array = numpy.asarray(y) + # make sure y_array is a 2D array even for a single curve + if y_array.ndim == 1: + y_array.shape = 1, -1 + elif y_array.ndim not in [1, 2]: + raise IndexError("y must be a 1D or 2D array") + + # First curve + specf = savespec(fname, x, y_array[0], xlabel, ylabels[0], fmt=fmt, + scan_number=1, mode="w", write_file_header=True, + close_file=False) + # Other curves + for i in range(1, y_array.shape[0]): + specf = savespec(specf, x, y_array[i], xlabel, ylabels[i], + fmt=fmt, scan_number=i + 1, mode="w", + write_file_header=False, close_file=False) + # close file if we created it if not hasattr(fname, "write"): specf.close() @@ -307,9 +319,11 @@ def savespec(specfile, x, y, xlabel="X", ylabel="Y", fmt="%.7g", or append mode. If a file name is provided, a new file is open in write mode (existing file with the same name will be lost) :param x: 1D-Array (or list) of abscissa values - :param y: 1D-array (or list) of ordinates values + :param y: 1D-array (or list), or list of them of ordinates values. + All dataset must have the same length as x :param xlabel: Abscissa label (default ``"X"``) - :param ylabel: Ordinate label + :param ylabel: Ordinate label, may be a list of labels when multiple curves + are to be saved together. :param fmt: Format string for data. You can specify a short format string that defines a single format for both ``x`` and ``y`` values, or a list of two different format strings (e.g. ``["%d", "%.7g"]``). @@ -333,40 +347,51 @@ def savespec(specfile, x, y, xlabel="X", ylabel="Y", fmt="%.7g", x_array = numpy.asarray(x) y_array = numpy.asarray(y) + if y_array.ndim > 2: + raise IndexError("Y columns must have be packed as 1D") - if y_array.shape[0] != x_array.shape[0]: + if y_array.shape[-1] != x_array.shape[0]: raise IndexError("X and Y columns must have the same length") + if y_array.ndim == 2: + assert isinstance(ylabel, (list, tuple)) + assert y_array.shape[0] == len(ylabel) + labels = (xlabel, *ylabel) + else: + labels = (xlabel, ylabel) + data = numpy.vstack((x_array, y_array)) + ncol = data.shape[0] + assert len(labels) == ncol + + print(xlabel, ylabel, fmt, ncol, x_array, y_array) if isinstance(fmt, string_types) and fmt.count("%") == 1: - full_fmt_string = fmt + " " + fmt + "\n" - elif isinstance(fmt, (list, tuple)) and len(fmt) == 2: - full_fmt_string = " ".join(fmt) + "\n" + full_fmt_string = " ".join([fmt] * ncol) + elif isinstance(fmt, (list, tuple)) and len(fmt) == ncol: + full_fmt_string = " ".join(fmt) else: - raise ValueError("fmt must be a single format string or a list of " + - "two format strings") + raise ValueError("`fmt` must be a single format string or a list of " + + "format strings with as many format as ncolumns") if not hasattr(specfile, "write"): f = builtin_open(specfile, mode) else: f = specfile - output = "" - - current_date = "#D %s\n" % (time.ctime(time.time())) - + current_date = "#D %s" % (time.ctime(time.time())) if write_file_header: - output += "#F %s\n" % f.name - output += current_date - output += "\n" - - output += "#S %d %s\n" % (scan_number, ylabel) - output += current_date - output += "#N 2\n" - output += "#L %s %s\n" % (xlabel, ylabel) - for i in range(y_array.shape[0]): - output += full_fmt_string % (x_array[i], y_array[i]) - output += "\n" + lines = [ "#F %s" % f.name, current_date, ""] + else: + lines = [""] + lines += [ "#S %d %s" % (scan_number, labels[1]), + current_date, + "#N %d" % ncol, + "#L " + " ".join(labels)] + + for i in data.T: + lines.append(full_fmt_string % tuple(i)) + lines.append("") + output = "\n".join(lines) f.write(output.encode()) if close_file: @@ -406,7 +431,7 @@ def h5ls(h5group, lvl=0): if is_group(h5group): h5f = h5group elif isinstance(h5group, string_types): - h5f = open(h5group) # silx.io.open + h5f = open(h5group) # silx.io.open else: raise TypeError("h5group must be a hdf5-like group object or a file name.") @@ -735,6 +760,26 @@ def is_softlink(obj): return t == H5Type.SOFT_LINK +def is_externallink(obj): + """ + True if the object is a h5py.ExternalLink-like object. + + :param obj: An object + """ + t = get_h5_class(obj) + return t == H5Type.EXTERNAL_LINK + + +def is_link(obj): + """ + True if the object is a h5py link-like object. + + :param obj: An object + """ + t = get_h5_class(obj) + return t in {H5Type.SOFT_LINK, H5Type.EXTERNAL_LINK} + + def get_data(url): """Returns a numpy data from an URL. @@ -791,16 +836,16 @@ def get_data(url): raise ValueError("Data path from URL '%s' is not a dataset" % url.path()) if data_slice is not None: - data = data[data_slice] + data = h5py_read_dataset(data, index=data_slice) else: # works for scalar and array - data = data[()] + data = h5py_read_dataset(data) elif url.scheme() == "fabio": import fabio data_slice = url.data_slice() if data_slice is None: - data_slice = (0, ) + data_slice = (0,) if data_slice is None or len(data_slice) != 1: raise ValueError("Fabio slice expect a single frame, but %s found" % data_slice) index = data_slice[0] @@ -844,8 +889,8 @@ def rawfile_to_h5_external_dataset(bin_file, output_url, shape, dtype, """ assert isinstance(output_url, silx.io.url.DataUrl) assert isinstance(shape, (tuple, list)) - v_majeur, v_mineur, v_micro = h5py.version.version.split('.') - if v_majeur <= '2' and v_mineur < '9': + v_majeur, v_mineur, v_micro = [int(i) for i in h5py.version.version.split('.')[:3]] + if calc_hexversion(v_majeur, v_mineur, v_micro)< calc_hexversion(2,9,0): raise Exception('h5py >= 2.9 should be installed to access the ' 'external feature.') @@ -915,3 +960,183 @@ def vol_to_h5_external_dataset(vol_file, output_url, info_file=None, shape=shape, dtype=vol_dtype, overwrite=overwrite) + + +def h5py_decode_value(value, encoding="utf-8", errors="surrogateescape"): + """Keep bytes when value cannot be decoded + + :param value: bytes or array of bytes + :param encoding str: + :param errors str: + """ + try: + if numpy.isscalar(value): + return value.decode(encoding, errors=errors) + str_item = [b.decode(encoding, errors=errors) for b in value.flat] + return numpy.array(str_item, dtype=object).reshape(value.shape) + except UnicodeDecodeError: + return value + + +def h5py_encode_value(value, encoding="utf-8", errors="surrogateescape"): + """Keep string when value cannot be encoding + + :param value: string or array of strings + :param encoding str: + :param errors str: + """ + try: + if numpy.isscalar(value): + return value.encode(encoding, errors=errors) + bytes_item = [s.encode(encoding, errors=errors) for s in value.flat] + return numpy.array(bytes_item, dtype=object).reshape(value.shape) + except UnicodeEncodeError: + return value + + +class H5pyDatasetReadWrapper: + """Wrapper to handle H5T_STRING decoding on-the-fly when reading + a dataset. Uniform behaviour for h5py 2.x and h5py 3.x + + h5py abuses H5T_STRING with ASCII character set + to store `bytes`: dset[()] = b"..." + Therefore an H5T_STRING with ASCII encoding is not decoded by default. + """ + + H5PY_AUTODECODE_NONASCII = int(h5py.version.version.split(".")[0]) < 3 + + def __init__(self, dset, decode_ascii=False): + """ + :param h5py.Dataset dset: + :param bool decode_ascii: + """ + try: + string_info = h5py.h5t.check_string_dtype(dset.dtype) + except AttributeError: + # h5py < 2.10 + try: + idx = dset.id.get_type().get_cset() + except AttributeError: + # Not an H5T_STRING + encoding = None + else: + encoding = ["ascii", "utf-8"][idx] + else: + # h5py >= 2.10 + try: + encoding = string_info.encoding + except AttributeError: + # Not an H5T_STRING + encoding = None + if encoding == "ascii" and not decode_ascii: + encoding = None + if encoding != "ascii" and self.H5PY_AUTODECODE_NONASCII: + # Decoding is already done by the h5py library + encoding = None + if encoding == "ascii": + # ASCII can be decoded as UTF-8 + encoding = "utf-8" + self._encoding = encoding + self._dset = dset + + def __getitem__(self, args): + value = self._dset[args] + if self._encoding: + return h5py_decode_value(value, encoding=self._encoding) + else: + return value + + +class H5pyAttributesReadWrapper: + """Wrapper to handle H5T_STRING decoding on-the-fly when reading + an attribute. Uniform behaviour for h5py 2.x and h5py 3.x + + h5py abuses H5T_STRING with ASCII character set + to store `bytes`: dset[()] = b"..." + Therefore an H5T_STRING with ASCII encoding is not decoded by default. + """ + + H5PY_AUTODECODE = int(h5py.version.version.split(".")[0]) >= 3 + + def __init__(self, attrs, decode_ascii=False): + """ + :param h5py.Dataset dset: + :param bool decode_ascii: + """ + self._attrs = attrs + self._decode_ascii = decode_ascii + + def __getitem__(self, args): + value = self._attrs[args] + + # Get the string encoding (if a string) + try: + dtype = self._attrs.get_id(args).dtype + except AttributeError: + # h5py < 2.10 + attr_id = h5py.h5a.open(self._attrs._id, self._attrs._e(args)) + try: + idx = attr_id.get_type().get_cset() + except AttributeError: + # Not an H5T_STRING + return value + else: + encoding = ["ascii", "utf-8"][idx] + else: + # h5py >= 2.10 + try: + encoding = h5py.h5t.check_string_dtype(dtype).encoding + except AttributeError: + # Not an H5T_STRING + return value + + if self.H5PY_AUTODECODE: + if encoding == "ascii" and not self._decode_ascii: + # Undo decoding by the h5py library + return h5py_encode_value(value, encoding="utf-8") + else: + if encoding == "ascii" and self._decode_ascii: + # Decode ASCII as UTF-8 for consistency + return h5py_decode_value(value, encoding="utf-8") + + # Decoding is already done by the h5py library + return value + + def items(self): + for k in self._attrs.keys(): + yield k, self[k] + + +def h5py_read_dataset(dset, index=tuple(), decode_ascii=False): + """Read data from dataset object. UTF-8 strings will be + decoded while ASCII strings will only be decoded when + `decode_ascii=True`. + + :param h5py.Dataset dset: + :param index: slicing (all by default) + :param bool decode_ascii: + """ + return H5pyDatasetReadWrapper(dset, decode_ascii=decode_ascii)[index] + + +def h5py_read_attribute(attrs, name, decode_ascii=False): + """Read data from attributes. UTF-8 strings will be + decoded while ASCII strings will only be decoded when + `decode_ascii=True`. + + :param h5py.AttributeManager attrs: + :param str name: attribute name + :param bool decode_ascii: + """ + return H5pyAttributesReadWrapper(attrs, decode_ascii=decode_ascii)[name] + + +def h5py_read_attributes(attrs, decode_ascii=False): + """Read data from attributes. UTF-8 strings will be + decoded while ASCII strings will only be decoded when + `decode_ascii=True`. + + :param h5py.AttributeManager attrs: + :param bool decode_ascii: + """ + return dict(H5pyAttributesReadWrapper(attrs, decode_ascii=decode_ascii).items()) diff --git a/silx/math/colormap.pyx b/silx/math/colormap.pyx index 2495f3c..2cefe04 100644 --- a/silx/math/colormap.pyx +++ b/silx/math/colormap.pyx @@ -56,6 +56,8 @@ else: # Fallback DEFAULT_NUM_THREADS = 1 # Number of threads to use for the computation (initialized to up to 4) +cdef int USE_OPENMP_THRESHOLD = 1000 +"""OpenMP is not used for arrays with less elements than this threshold""" # Supported data types ctypedef fused data_types: @@ -312,7 +314,7 @@ cdef image_types[:, ::1] compute_cmap( cdef image_types[:, ::1] output cdef double scale, value, normalized_vmin, normalized_vmax cdef int length, nb_channels, nb_colors - cdef int channel, index, lut_index + cdef int channel, index, lut_index, num_threads nb_colors = colors.shape[0] nb_channels = colors.shape[1] @@ -332,8 +334,15 @@ cdef image_types[:, ::1] compute_cmap( else: scale = nb_colors / (normalized_vmax - normalized_vmin) + if length < USE_OPENMP_THRESHOLD: + num_threads = 1 + else: + num_threads = min( + DEFAULT_NUM_THREADS, + int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))) + with nogil: - for index in prange(length, num_threads=DEFAULT_NUM_THREADS): + for index in prange(length, num_threads=num_threads): value = normalization.apply_double( data[index], vmin, vmax) @@ -386,7 +395,7 @@ cdef image_types[:, ::1] compute_cmap_with_lut( cdef image_types[:, ::1] lut cdef int type_min, type_max cdef int nb_channels, length - cdef int channel, index, lut_index + cdef int channel, index, lut_index, num_threads length = data.size nb_channels = colors.shape[1] @@ -412,9 +421,16 @@ cdef image_types[:, ::1] compute_cmap_with_lut( output = numpy.empty((length, nb_channels), dtype=colors_dtype) + if length < USE_OPENMP_THRESHOLD: + num_threads = 1 + else: + num_threads = min( + DEFAULT_NUM_THREADS, + int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))) + with nogil: # Apply LUT - for index in prange(length, num_threads=DEFAULT_NUM_THREADS): + for index in prange(length, num_threads=num_threads): lut_index = data[index] - type_min for channel in range(nb_channels): output[index, channel] = lut[lut_index, channel] diff --git a/silx/math/fft/test/test_fft.py b/silx/math/fft/test/test_fft.py index 14b1243..9ef2fd2 100644 --- a/silx/math/fft/test/test_fft.py +++ b/silx/math/fft/test/test_fft.py @@ -156,6 +156,9 @@ class TestFFT(ParametricTestCase): tol = self.tol[np.dtype(input_data.dtype)] if trdim == "3D": tol *= 10 # Error is relatively high in high dimensions + # It seems that cuda has problems with C2D batched 1D + if trdim == "batched_1D" and backend == "cuda" and mode == "C2C": + tol *= 10 # Python < 3.5 does not want to mix **extra_args with existing kwargs fft_args = { @@ -177,9 +180,10 @@ class TestFFT(ParametricTestCase): res = F.fft(input_data) res_np = F_np.fft(input_data) mae = self.calc_mae(res, res_np) + all_close = np.allclose(res, res_np, atol=tol, rtol=tol), self.assertTrue( - mae < np.abs(input_data.max()) * tol, - "FFT %s:%s, MAE(%s, numpy) = %f" % (mode, trdim, backend, mae) + all_close, + "FFT %s:%s, MAE(%s, numpy) = %f (tol = %.2e)" % (mode, trdim, backend, mae, tol) ) # Inverse FFT diff --git a/silx/math/fit/bgtheories.py b/silx/math/fit/bgtheories.py index ccb556e..631c43e 100644 --- a/silx/math/fit/bgtheories.py +++ b/silx/math/fit/bgtheories.py @@ -1,7 +1,7 @@ # coding: utf-8 #/*########################################################################## # -# Copyright (c) 2004-2019 European Synchrotron Radiation Facility +# Copyright (c) 2004-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -281,7 +281,7 @@ def estimate_strip(x, y): """ estimated_par = [CONFIG["StripWidth"], CONFIG["StripIterations"]] - constraints = numpy.zeros((len(estimated_par), 3), numpy.float) + constraints = numpy.zeros((len(estimated_par), 3), numpy.float64) # code = 3: FIXED constraints[0][0] = 3 constraints[1][0] = 3 @@ -295,7 +295,7 @@ def estimate_snip(x, y): set constraints to FIXED. """ estimated_par = [CONFIG["SnipWidth"]] - constraints = numpy.zeros((len(estimated_par), 3), numpy.float) + constraints = numpy.zeros((len(estimated_par), 3), numpy.float64) # code = 3: FIXED constraints[0][0] = 3 return estimated_par, constraints @@ -321,7 +321,7 @@ def estimate_poly(x, y, deg=2): CONFIG["StripWidth"], CONFIG["StripIterations"]) pcoeffs = numpy.polyfit(x, y, deg) - cons = numpy.zeros((deg + 1, 3), numpy.float) + cons = numpy.zeros((deg + 1, 3), numpy.float64) return pcoeffs, cons diff --git a/silx/math/fit/fitmanager.py b/silx/math/fit/fitmanager.py index 2dc63a1..b60e073 100644 --- a/silx/math/fit/fitmanager.py +++ b/silx/math/fit/fitmanager.py @@ -727,12 +727,12 @@ class FitManager(object): :param xmax: Upper value of x values to use for fitting """ if y is None: - self.xdata0 = numpy.array([], numpy.float) - self.ydata0 = numpy.array([], numpy.float) - # self.sigmay0 = numpy.array([], numpy.float) - self.xdata = numpy.array([], numpy.float) - self.ydata = numpy.array([], numpy.float) - # self.sigmay = numpy.array([], numpy.float) + self.xdata0 = numpy.array([], numpy.float64) + self.ydata0 = numpy.array([], numpy.float64) + # self.sigmay0 = numpy.array([], numpy.float64) + self.xdata = numpy.array([], numpy.float64) + self.ydata = numpy.array([], numpy.float64) + # self.sigmay = numpy.array([], numpy.float64) else: self.ydata0 = numpy.array(y) @@ -886,7 +886,7 @@ class FitManager(object): :return: Output of the fit function with ``x`` as input and ``pars`` as fit parameters. """ - result = numpy.zeros(numpy.shape(x), numpy.float) + result = numpy.zeros(numpy.shape(x), numpy.float64) if self.selectedbg is not None: bg_pars_list = self.bgtheories[self.selectedbg].parameters @@ -1036,7 +1036,7 @@ def test(): from . import bgtheories # Create synthetic data with a sum of gaussian functions - x = numpy.arange(1000).astype(numpy.float) + x = numpy.arange(1000).astype(numpy.float64) p = [1000, 100., 250, 255, 690., 45, diff --git a/silx/math/fit/fittheories.py b/silx/math/fit/fittheories.py index f733d1a..6b19a38 100644 --- a/silx/math/fit/fittheories.py +++ b/silx/math/fit/fittheories.py @@ -213,7 +213,7 @@ class FitTheories(object): """ pcoeffs = numpy.polyfit(x, y, n) - constraints = numpy.zeros((n + 1, 3), numpy.float) + constraints = numpy.zeros((n + 1, 3), numpy.float64) return pcoeffs, constraints def estimate_quadratic(self, x, y): @@ -298,7 +298,7 @@ class FitTheories(object): :return: List of peak indices """ # add padding - ysearch = numpy.ones((len(y) + 2 * fwhm,), numpy.float) + ysearch = numpy.ones((len(y) + 2 * fwhm,), numpy.float64) ysearch[0:fwhm] = y[0] ysearch[-1:-fwhm - 1:-1] = y[len(y)-1] ysearch[fwhm:fwhm + len(y)] = y[:] @@ -389,7 +389,7 @@ class FitTheories(object): xw = x yw = y - bg - cons = numpy.zeros((len(param), 3), numpy.float) + cons = numpy.zeros((len(param), 3), numpy.float64) # peak height must be positive cons[0:len(param):3, 0] = CPOSITIVE @@ -405,10 +405,10 @@ class FitTheories(object): shape = [max(1, int(x)) for x in (param[1:len(param):3])] cons[1:len(param):3, 1] = min(xw) * numpy.ones( shape, - numpy.float) + numpy.float64) cons[1:len(param):3, 2] = max(xw) * numpy.ones( shape, - numpy.float) + numpy.float64) # ensure fwhm is positive cons[2:len(param):3, 0] = CPOSITIVE @@ -420,7 +420,7 @@ class FitTheories(object): full_output=True) # set final constraints based on config parameters - cons = numpy.zeros((len(fittedpar), 3), numpy.float) + cons = numpy.zeros((len(fittedpar), 3), numpy.float64) peak_index = 0 for i in range(len(peaks)): # Setup height area constrains @@ -524,7 +524,7 @@ class FitTheories(object): # get the number of found peaks npeaks = len(fittedpar) // 3 estimated_parameters = [] - estimated_constraints = numpy.zeros((4 * npeaks, 3), numpy.float) + estimated_constraints = numpy.zeros((4 * npeaks, 3), numpy.float64) for i in range(npeaks): for j in range(3): estimated_parameters.append(fittedpar[3 * i + j]) @@ -579,7 +579,7 @@ class FitTheories(object): fittedpar, cons = self.estimate_height_position_fwhm(x, y) npeaks = len(fittedpar) // 3 newpar = [] - newcons = numpy.zeros((4 * npeaks, 3), numpy.float) + newcons = numpy.zeros((4 * npeaks, 3), numpy.float64) # find out related parameters proper index if not self.config['NoConstraintsFlag']: if self.config['SameFwhmFlag']: @@ -640,7 +640,7 @@ class FitTheories(object): fittedpar, cons = self.estimate_height_position_fwhm(x, y) npeaks = len(fittedpar) // 3 newpar = [] - newcons = numpy.zeros((5 * npeaks, 3), numpy.float) + newcons = numpy.zeros((5 * npeaks, 3), numpy.float64) # find out related parameters proper index if not self.config['NoConstraintsFlag']: if self.config['SameFwhmFlag']: @@ -741,7 +741,7 @@ class FitTheories(object): fittedpar, cons = self.estimate_height_position_fwhm(x, y) npeaks = len(fittedpar) // 3 newpar = [] - newcons = numpy.zeros((8 * npeaks, 3), numpy.float) + newcons = numpy.zeros((8 * npeaks, 3), numpy.float64) main_peak = 0 # find out related parameters proper index if not self.config['NoConstraintsFlag']: @@ -841,7 +841,7 @@ class FitTheories(object): newcons[8 * i + 7, 1] = self.config['MinStepTailHeightRatio'] newcons[8 * i + 7, 2] = self.config['MaxStepTailHeightRatio'] # if self.config['NoConstraintsFlag'] == 1: - # newcons=numpy.zeros((8*npeaks, 3),numpy.float) + # newcons=numpy.zeros((8*npeaks, 3),numpy.float64) if npeaks > 0: if g_term: if self.config['PositiveHeightAreaFlag']: @@ -931,7 +931,7 @@ class FitTheories(object): self.config["FwhmPoints"] * (x[1] - x[0])] # fwhm: default value # Setup constrains - newcons = numpy.zeros((3, 3), numpy.float) + newcons = numpy.zeros((3, 3), numpy.float64) if not self.config['NoConstraintsFlag']: # Setup height constrains if self.config['PositiveHeightAreaFlag']: @@ -983,7 +983,7 @@ class FitTheories(object): position = (xx[0] + xx[-1]) / 2.0 fwhm = xx[-1] - xx[0] largest = [height, position, fwhm, beamfwhm] - cons = numpy.zeros((4, 3), numpy.float) + cons = numpy.zeros((4, 3), numpy.float64) # Setup constrains if not self.config['NoConstraintsFlag']: # Setup height constrains @@ -1056,7 +1056,7 @@ class FitTheories(object): x[len(x)//2], # center: middle of x range self.config["FwhmPoints"] * (x[1] - x[0])] # fwhm: default value - newcons = numpy.zeros((3, 3), numpy.float) + newcons = numpy.zeros((3, 3), numpy.float64) # Setup constrains if not self.config['NoConstraintsFlag']: # Setup height constraints @@ -1123,7 +1123,7 @@ class FitTheories(object): npeaks = len(peaks) if not npeaks: fittedpar = [] - cons = numpy.zeros((len(fittedpar), 3), numpy.float) + cons = numpy.zeros((len(fittedpar), 3), numpy.float64) return fittedpar, cons fittedpar = [0.0, 0.0, 0.0, 0.0, 0.0] @@ -1153,7 +1153,7 @@ class FitTheories(object): fittedpar[4] = search_fwhm # setup constraints - cons = numpy.zeros((5, 3), numpy.float) + cons = numpy.zeros((5, 3), numpy.float64) cons[0, 0] = CFIXED # the number of gaussians if npeaks == 1: cons[1, 0] = CFIXED # the delta between peaks @@ -1337,7 +1337,7 @@ function, parameters list, configuration function and description. def test(a): from silx.math.fit import fitmanager - x = numpy.arange(1000).astype(numpy.float) + x = numpy.arange(1000).astype(numpy.float64) p = [1500, 100., 50.0, 1500, 700., 50.0] y_synthetic = functions.sum_gauss(x, *p) + 1 diff --git a/silx/math/fit/functions.pyx b/silx/math/fit/functions.pyx index ebbc37b..1f78563 100644 --- a/silx/math/fit/functions.pyx +++ b/silx/math/fit/functions.pyx @@ -1,6 +1,6 @@ # coding: utf-8 #/*########################################################################## -# Copyright (C) 2016-2018 European Synchrotron Radiation Facility +# Copyright (C) 2016-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -977,7 +977,7 @@ def periodic_gauss(x, *pars): raise IndexError("No parameters specified. " + "At least 5 parameters are required.") - newpars = numpy.zeros((pars[0], 3), numpy.float) + newpars = numpy.zeros((pars[0], 3), numpy.float64) for i in range(int(pars[0])): newpars[i, 0] = pars[2] newpars[i, 1] = pars[3] + i * pars[1] diff --git a/silx/math/fit/leastsq.py b/silx/math/fit/leastsq.py index 8c87d6b..3df1a35 100644 --- a/silx/math/fit/leastsq.py +++ b/silx/math/fit/leastsq.py @@ -1,7 +1,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2004-2017 European Synchrotron Radiation Facility +# Copyright (c) 2004-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -132,7 +132,7 @@ def leastsq(model, xdata, ydata, p0, sigma=None, calculating the numerical derivatives (for model_deriv=None). Normally the actual step length will be sqrt(epsfcn)*x Original Gefit module was using epsfcn 1.0e-5 while default value - is now numpy.finfo(numpy.float).eps as in scipy + is now numpy.finfo(numpy.float64).eps as in scipy :type epsfcn: *optional*, float :param deltachi: float @@ -205,7 +205,7 @@ def leastsq(model, xdata, ydata, p0, sigma=None, if sigma is not None: sigma = numpy.asarray_chkfinite(sigma) else: - sigma = numpy.ones((ydata.shape), dtype=numpy.float) + sigma = numpy.ones((ydata.shape), dtype=numpy.float64) ydata.shape = -1 sigma.shape = -1 else: @@ -215,7 +215,7 @@ def leastsq(model, xdata, ydata, p0, sigma=None, if sigma is not None: sigma = numpy.asarray(sigma) else: - sigma = numpy.ones((ydata.shape), dtype=numpy.float) + sigma = numpy.ones((ydata.shape), dtype=numpy.float64) sigma.shape = -1 # get rid of NaN in input data idx = numpy.isfinite(ydata) @@ -289,9 +289,9 @@ def leastsq(model, xdata, ydata, p0, sigma=None, nparameters = len(parameters) if epsfcn is None: - epsfcn = numpy.finfo(numpy.float).eps + epsfcn = numpy.finfo(numpy.float64).eps else: - epsfcn = max(epsfcn, numpy.finfo(numpy.float).eps) + epsfcn = max(epsfcn, numpy.finfo(numpy.float64).eps) # check if constraints have been passed as text constrained_fit = False @@ -383,7 +383,7 @@ def leastsq(model, xdata, ydata, p0, sigma=None, newpar = fitparam + deltapar [0] else: newpar = parameters.__copy__() - pwork = numpy.zeros(deltapar.shape, numpy.float) + pwork = numpy.zeros(deltapar.shape, numpy.float64) for i in range(n_free): if constraints is None: pwork [0] [i] = fitparam [i] + deltapar [0] [i] @@ -567,7 +567,7 @@ def chisq_alpha_beta(model, parameters, x, y, weight, constraints=None, calculating the numerical derivatives (for model_deriv=None). Normally the actual step length will be sqrt(epsfcn)*x Original Gefit module was using epsfcn 1.0e-10 while default value - is now numpy.finfo(numpy.float).eps as in scipy + is now numpy.finfo(numpy.float64).eps as in scipy :type epsfcn: *optional*, float :param left_derivative: @@ -595,9 +595,9 @@ def chisq_alpha_beta(model, parameters, x, y, weight, constraints=None, Sequence with the indices of the original parameters considered in the calculations. """ if epsfcn is None: - epsfcn = numpy.finfo(numpy.float).eps + epsfcn = numpy.finfo(numpy.float64).eps else: - epsfcn = max(epsfcn, numpy.finfo(numpy.float).eps) + epsfcn = max(epsfcn, numpy.finfo(numpy.float64).eps) #nr0, nc = data.shape n_param = len(parameters) if constraints is None: @@ -644,9 +644,9 @@ def chisq_alpha_beta(model, parameters, x, y, weight, constraints=None, print("Initial value = %f" % parameters[i]) print("Limits are %f and %f" % (pmin, pmax)) print("Parameter will be kept at its starting value") - fitparam = numpy.array(fitparam, numpy.float) - alpha = numpy.zeros((n_free, n_free), numpy.float) - beta = numpy.zeros((1, n_free), numpy.float) + fitparam = numpy.array(fitparam, numpy.float64) + alpha = numpy.zeros((n_free, n_free), numpy.float64) + beta = numpy.zeros((1, n_free), numpy.float64) #delta = (fitparam + numpy.equal(fitparam, 0.0)) * 0.00001 delta = (fitparam + numpy.equal(fitparam, 0.0)) * numpy.sqrt(epsfcn) nr = y.size @@ -803,7 +803,7 @@ def _get_sigma_parameters(parameters, sigma0, constraints): if constraints is None: return sigma0 n_free = 0 - sigma_par = numpy.zeros(parameters.shape, numpy.float) + sigma_par = numpy.zeros(parameters.shape, numpy.float64) for i in range(len(constraints)): if constraints[i][0] == CFREE: sigma_par [i] = sigma0[n_free] @@ -860,7 +860,7 @@ def main(argv=None): return numpy.exp(x * numpy.less(abs(x), 250)) -\ 1.0 * numpy.greater_equal(abs(x), 250) - xx = numpy.arange(npoints, dtype=numpy.float) + xx = numpy.arange(npoints, dtype=numpy.float64) yy = gauss(xx, *[10.5, 2, 1000.0, 20., 15]) sy = numpy.sqrt(abs(yy)) parameters = [0.0, 1.0, 900.0, 25., 10] diff --git a/silx/math/fit/test/test_fit.py b/silx/math/fit/test/test_fit.py index 372d6cb..3fdf394 100644 --- a/silx/math/fit/test/test_fit.py +++ b/silx/math/fit/test/test_fit.py @@ -1,6 +1,6 @@ # coding: utf-8 # /*########################################################################## -# Copyright (C) 2016 European Synchrotron Radiation Facility +# Copyright (C) 2016-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -57,7 +57,7 @@ class Test_leastsq(unittest.TestCase): self.my_exp = myexp def gauss(x, *params): - params = numpy.array(params, copy=False, dtype=numpy.float) + params = numpy.array(params, copy=False, dtype=numpy.float64) result = params[0] + params[1] * x for i in range(2, len(params), 3): p = params[i:(i+3)] @@ -69,7 +69,7 @@ class Test_leastsq(unittest.TestCase): def gauss_derivative(x, params, idx): if idx == 0: - return numpy.ones(len(x), numpy.float) + return numpy.ones(len(x), numpy.float64) if idx == 1: return x gaussian_peak = (idx - 2) // 3 @@ -140,7 +140,7 @@ class Test_leastsq(unittest.TestCase): parameters_actual = [10.5, 2, 10000.0, 20., 150, 5000, 900., 300] x = numpy.arange(10000.) y = self.gauss(x, *parameters_actual) - delta = numpy.sqrt(numpy.finfo(numpy.float).eps) + delta = numpy.sqrt(numpy.finfo(numpy.float64).eps) for i in range(len(parameters_actual)): p = parameters_actual * 1 if p[i] == 0: diff --git a/silx/math/fit/test/test_fitmanager.py b/silx/math/fit/test/test_fitmanager.py index 7a643cb..acac242 100644 --- a/silx/math/fit/test/test_fitmanager.py +++ b/silx/math/fit/test/test_fitmanager.py @@ -125,7 +125,7 @@ class TestFitmanager(ParametricTestCase): """Test fit manager on synthetic data using a gaussian function and a linear background""" # Create synthetic data with a sum of gaussian functions - x = numpy.arange(1000).astype(numpy.float) + x = numpy.arange(1000).astype(numpy.float64) p = [1000, 100., 250, 255, 650., 45, @@ -186,7 +186,7 @@ class TestFitmanager(ParametricTestCase): """Test FitManager using a custom fit function defined in an external file and imported with FitManager.loadtheories""" # Create synthetic data with a sum of gaussian functions - x = numpy.arange(100).astype(numpy.float) + x = numpy.arange(100).astype(numpy.float64) # a, b, c are the fit parameters # d is a known scaling parameter that is set using configure() @@ -233,7 +233,7 @@ class TestFitmanager(ParametricTestCase): """Test FitManager using a custom fit function defined in an external file and imported with FitManager.loadtheories (legacy PyMca format)""" # Create synthetic data with a sum of gaussian functions - x = numpy.arange(100).astype(numpy.float) + x = numpy.arange(100).astype(numpy.float64) # a, b, c are the fit parameters # d is a known scaling parameter that is set using configure() @@ -279,7 +279,7 @@ class TestFitmanager(ParametricTestCase): """Test FitManager using a custom fit function imported with FitManager.addtheory""" # Create synthetic data with a sum of gaussian functions - x = numpy.arange(100).astype(numpy.float) + x = numpy.arange(100).astype(numpy.float64) # a, b, c are the fit parameters # d is a known scaling parameter that is set using configure() @@ -369,7 +369,7 @@ class TestFitmanager(ParametricTestCase): for theory_name, theory_fun in (('Step Down', sum_stepdown), ('Step Up', sum_stepup)): # Create synthetic data with a sum of gaussian functions - x = numpy.arange(1000).astype(numpy.float) + x = numpy.arange(1000).astype(numpy.float64) # ('Height', 'Position', 'FWHM') p = [1000, 439, 250] @@ -407,7 +407,7 @@ def cubic(x, a, b, c, d): class TestPolynomials(unittest.TestCase): """Test polynomial fit theories and fit background""" def setUp(self): - self.x = numpy.arange(100).astype(numpy.float) + self.x = numpy.arange(100).astype(numpy.float64) def testQuadraticBg(self): gaussian_params = [100, 45, 8] diff --git a/silx/opencl/backprojection.py b/silx/opencl/backprojection.py index 5a4087b..65a9836 100644 --- a/silx/opencl/backprojection.py +++ b/silx/opencl/backprojection.py @@ -164,9 +164,7 @@ class Backprojection(OpenclProcessing): def _allocate_memory(self): # Host memory self.slice = np.zeros(self.dimrec_shape, dtype=np.float32) - self.is_cpu = False - if self.device.type == "CPU": - self.is_cpu = True + self._use_textures = self.check_textures_availability() # Device memory self.buffers = [ @@ -180,7 +178,7 @@ class Backprojection(OpenclProcessing): self.d_sino = self.cl_mem["d_sino"] # shorthand # Texture memory (if relevant) - if not(self.is_cpu): + if self._use_textures: self._allocate_textures() # Local memory @@ -199,7 +197,14 @@ class Backprojection(OpenclProcessing): self.cl_mem["d_axes"][:] = np.ones(self.num_projs, dtype="f") * self.axis_pos def _init_kernels(self): - OpenclProcessing.compile_kernels(self, self.kernel_files) + compile_options = None + if not(self._use_textures): + compile_options = "-DDONT_USE_TEXTURES" + OpenclProcessing.compile_kernels( + self, + self.kernel_files, + compile_options=compile_options + ) # check that workgroup can actually be (16, 16) self.compiletime_workgroup_size = self.kernels.max_workgroup_size("backproj_cpu_kernel") # Workgroup and ndrange sizes are always the same @@ -209,7 +214,7 @@ class Backprojection(OpenclProcessing): _idivup(int(self.dimrec_shape[0]), 32) * self.wg[1] ) # Prepare arguments for the kernel call - if self.is_cpu: + if not(self._use_textures): d_sino_ref = self.d_sino.data else: d_sino_ref = self.d_sino_tex @@ -242,15 +247,7 @@ class Backprojection(OpenclProcessing): """ Allocate the texture for the sinogram. """ - self.d_sino_tex = pyopencl.Image( - self.ctx, - mf.READ_ONLY | mf.USE_HOST_PTR, - pyopencl.ImageFormat( - pyopencl.channel_order.INTENSITY, - pyopencl.channel_type.FLOAT - ), - hostbuf=np.zeros(self.shape[::-1], dtype=np.float32) - ) + self.d_sino_tex = self.allocate_texture(self.shape) def _init_filter(self, filter_name): """Filter initialization @@ -289,7 +286,7 @@ class Backprojection(OpenclProcessing): sino2 = sino if not(sino.flags["C_CONTIGUOUS"] and sino.dtype == np.float32): sino2 = np.ascontiguousarray(sino, dtype=np.float32) - if self.is_cpu: + if not(self._use_textures): ev = pyopencl.enqueue_copy( self.queue, self.d_sino.data, @@ -309,7 +306,7 @@ class Backprojection(OpenclProcessing): return EventDescription(what, ev) def _transfer_device_to_texture(self, d_sino): - if self.is_cpu: + if not(self._use_textures): if id(self.d_sino) == id(d_sino): return ev = pyopencl.enqueue_copy( @@ -343,7 +340,7 @@ class Backprojection(OpenclProcessing): with self.sem: events.append(self._transfer_to_texture(sino)) # Call the backprojection kernel - if self.is_cpu: + if not(self._use_textures): kernel_to_call = self.kernels.backproj_cpu_kernel else: kernel_to_call = self.kernels.backproj_kernel diff --git a/silx/opencl/common.py b/silx/opencl/common.py index 110d941..002c15d 100644 --- a/silx/opencl/common.py +++ b/silx/opencl/common.py @@ -34,7 +34,7 @@ __author__ = "Jerome Kieffer" __contact__ = "Jerome.Kieffer@ESRF.eu" __license__ = "MIT" __copyright__ = "2012-2017 European Synchrotron Radiation Facility, Grenoble, France" -__date__ = "28/11/2019" +__date__ = "30/11/2020" __status__ = "stable" __all__ = ["ocl", "pyopencl", "mf", "release_cl_buffers", "allocate_cl_buffers", "measure_workgroup_size", "kernel_workgroup_size"] @@ -46,10 +46,8 @@ import numpy from .utils import get_opencl_code - logger = logging.getLogger(__name__) - if os.environ.get("SILX_OPENCL") in ["0", "False"]: logger.info("Use of OpenCL has been disabled from environment variable: SILX_OPENCL=0") pyopencl = None @@ -70,13 +68,13 @@ else: mf = pyopencl.mem_flags if pyopencl is None: + # Define default mem flags class mf(object): WRITE_ONLY = 1 READ_ONLY = 1 READ_WRITE = 1 - FLOP_PER_CORE = {"GPU": 64, # GPU, Fermi at least perform 64 flops per cycle/multicore, G80 were at 24 or 48 ... "CPU": 4, # CPU, at least intel's have 4 operation per cycle "ACC": 8} # ACC: the Xeon-phi (MIC) appears to be able to process 8 Flops per hyperthreaded-core @@ -108,6 +106,7 @@ class Device(object): """ Simple class that contains the structure of an OpenCL device """ + def __init__(self, name="None", dtype=None, version=None, driver_version=None, extensions="", memory=None, available=None, cores=None, frequency=None, flop_core=None, idx=0, workgroup=1): @@ -174,6 +173,7 @@ class Platform(object): """ Simple class that contains the structure of an OpenCL platform """ + def __init__(self, name="None", vendor="None", version=None, extensions=None, idx=0): """ Class containing all descriptions of a platform and all devices description within that platform. @@ -225,6 +225,8 @@ class Platform(object): def _measure_workgroup_size(device_or_context, fast=False): """Mesure the maximal work group size of the given device + DEPRECATED since not perfectly correct ! + :param device_or_context: instance of pyopencl.Device or pyopencl.Context or 2-tuple (platformid,deviceid) :param fast: ask the kernel the valid value, don't probe it @@ -318,7 +320,7 @@ class OpenCL(object): #################################################### extensions = device.extensions if (pypl.vendor == "NVIDIA Corporation") and ('cl_khr_fp64' in extensions): - extensions += ' cl_khr_int64_base_atomics cl_khr_int64_extended_atomics' + extensions += ' cl_khr_int64_base_atomics cl_khr_int64_extended_atomics' try: devtype = pyopencl.device_type.to_string(device.type).upper() except ValueError: @@ -573,6 +575,53 @@ def allocate_cl_buffers(buffers, device=None, context=None): return mem +def allocate_texture(ctx, shape, hostbuf=None, support_1D=False): + """ + Allocate an OpenCL image ("texture"). + + :param ctx: OpenCL context + :param shape: Shape of the image. Note that pyopencl and OpenCL < 1.2 + do not support 1D images, so 1D images are handled as 2D with one row + :param support_1D: force the image to be 1D if the shape has only one dim + """ + if len(shape) == 1 and not(support_1D): + shape = (1,) + shape + return pyopencl.Image( + ctx, + pyopencl.mem_flags.READ_ONLY | pyopencl.mem_flags.USE_HOST_PTR, + pyopencl.ImageFormat( + pyopencl.channel_order.INTENSITY, + pyopencl.channel_type.FLOAT + ), + hostbuf=numpy.zeros(shape[::-1], dtype=numpy.float32) + ) + + +def check_textures_availability(ctx): + """ + Check whether textures are supported on the current OpenCL context. + + :param ctx: OpenCL context + """ + try: + dummy_texture = allocate_texture(ctx, (16, 16)) + # Need to further access some attributes (pocl) + dummy_height = dummy_texture.height + textures_available = True + del dummy_texture, dummy_height + except (pyopencl.RuntimeError, pyopencl.LogicError): + textures_available = False + # Nvidia Fermi GPUs (compute capability 2.X) do not support opencl read_imagef + # There is no way to detect this until a kernel is compiled + try: + cc = ctx.devices[0].compute_capability_major_nv + textures_available &= (cc >= 3) + except (pyopencl.LogicError, AttributeError): # probably not a Nvidia GPU + pass + # + return textures_available + + def measure_workgroup_size(device): """Measure the actual size of the workgroup @@ -599,12 +648,25 @@ def measure_workgroup_size(device): return res -def kernel_workgroup_size(program, kernel): - """Extract the compile time maximum workgroup size +def query_kernel_info(program, kernel, what="WORK_GROUP_SIZE"): + """Extract the compile time information from a kernel :param program: OpenCL program :param kernel: kernel or name of the kernel - :return: the maximum acceptable workgroup size for the given kernel + :param what: what is the query about ? + :return: int or 3-int for the workgroup size. + + Possible information available are: + * 'COMPILE_WORK_GROUP_SIZE': Returns the work-group size specified inside the kernel (__attribute__((reqd_work_gr oup_size(X, Y, Z)))) + * 'GLOBAL_WORK_SIZE': maximum global size that can be used to execute a kernel #OCL2.1! + * 'LOCAL_MEM_SIZE': amount of local memory in bytes being used by the kernel + * 'PREFERRED_WORK_GROUP_SIZE_MULTIPLE': preferred multiple of workgroup size for launch. This is a performance hint. + * 'PRIVATE_MEM_SIZE' Returns the minimum amount of private memory, in bytes, used by each workitem in the kernel + * 'WORK_GROUP_SIZE': maximum work-group size that can be used to execute a kernel on a specific device given by device + + Further information on: + https://www.khronos.org/registry/OpenCL/sdk/1.1/docs/man/xhtml/clGetKernelWorkGroupInfo.html + """ assert isinstance(program, pyopencl.Program) if not isinstance(kernel, pyopencl.Kernel): @@ -613,5 +675,15 @@ def kernel_workgroup_size(program, kernel): kernel = program.__getattr__(kernel_name) device = program.devices[0] - query_wg = pyopencl.kernel_work_group_info.WORK_GROUP_SIZE + query_wg = getattr(pyopencl.kernel_work_group_info, what) return kernel.get_work_group_info(query_wg, device) + + +def kernel_workgroup_size(program, kernel): + """Extract the compile time maximum workgroup size + + :param program: OpenCL program + :param kernel: kernel or name of the kernel + :return: the maximum acceptable workgroup size for the given kernel + """ + return query_kernel_info(program, kernel, what="WORK_GROUP_SIZE") diff --git a/silx/opencl/convolution.py b/silx/opencl/convolution.py index 138b985..15ef931 100644 --- a/silx/opencl/convolution.py +++ b/silx/opencl/convolution.py @@ -91,17 +91,8 @@ class Convolution(OpenclProcessing): } extra_opts = extra_options or {} self.extra_options.update(extra_opts) - self.is_cpu = (self.device.type == "CPU") self.use_textures = not(self.extra_options["dont_use_textures"]) - self.use_textures *= not(self.is_cpu) - # Nvidia Fermi GPUs (compute capability 2.X) do not support opencl read_imagef - try: - cc = self.ctx.devices[0].compute_capability_major_nv - self.use_textures *= (cc >= 3) - except cl.LogicError: # probably not a Nvidia GPU - pass - except AttributeError: # probably not a Nvidia GPU - pass + self.use_textures &= self.check_textures_availability() def _get_dimensions(self, shape, kernel): self.shape = shape diff --git a/silx/opencl/processing.py b/silx/opencl/processing.py index 6b475b9..470b141 100644 --- a/silx/opencl/processing.py +++ b/silx/opencl/processing.py @@ -36,26 +36,23 @@ Common OpenCL abstract base classe for different processing from __future__ import absolute_import, print_function, division - __author__ = "Jerome Kieffer" __contact__ = "Jerome.Kieffer@ESRF.eu" __license__ = "MIT" __copyright__ = "European Synchrotron Radiation Facility, Grenoble, France" -__date__ = "05/08/2019" +__date__ = "04/12/2020" __status__ = "stable" - import os import logging import gc from collections import namedtuple import numpy import threading -from .common import ocl, pyopencl, release_cl_buffers, kernel_workgroup_size +from .common import ocl, pyopencl, release_cl_buffers, query_kernel_info, allocate_texture, check_textures_availability from .utils import concatenate_cl_kernel import platform - BufferDescription = namedtuple("BufferDescription", ["name", "size", "dtype", "flags"]) EventDescription = namedtuple("EventDescription", ["name", "event"]) @@ -85,13 +82,22 @@ class KernelContainer(object): return self.__dict__.get(name) def max_workgroup_size(self, kernel_name): - "Retrieve the compile time max_workgroup_size for a given kernel" + "Retrieve the compile time WORK_GROUP_SIZE for a given kernel" if isinstance(kernel_name, pyopencl.Kernel): kernel = kernel_name else: kernel = self.get_kernel(kernel_name) - return kernel_workgroup_size(self._program, kernel) + return query_kernel_info(self._program, kernel, "WORK_GROUP_SIZE") + + def min_workgroup_size(self, kernel_name): + "Retrieve the compile time PREFERRED_WORK_GROUP_SIZE_MULTIPLE for a given kernel" + if isinstance(kernel_name, pyopencl.Kernel): + kernel = kernel_name + else: + kernel = self.get_kernel(kernel_name) + + return query_kernel_info(self._program, kernel, "PREFERRED_WORK_GROUP_SIZE_MULTIPLE") class OpenclProcessing(object): @@ -149,6 +155,9 @@ class OpenclProcessing(object): self.program = None self.kernels = None + def check_textures_availability(self): + return check_textures_availability(self.ctx) + def __del__(self): """Destructor: release all buffers and programs """ @@ -156,8 +165,10 @@ class OpenclProcessing(object): self.reset_log() self.free_kernels() self.free_buffers() - except Exception: - pass + if self.queue is not None: + self.queue.finish() + except Exception as err: + logger.warning("%s: %s", type(err), err) self.queue = None self.device = None self.ctx = None @@ -287,6 +298,8 @@ class OpenclProcessing(object): if bool(value) != self.profile: with self.sem: self.profile = bool(value) + if self.queue is not None: + self.queue.finish() if self.profile: self.queue = pyopencl.CommandQueue(self.ctx, properties=pyopencl.command_queue_properties.PROFILING_ENABLE) @@ -304,24 +317,7 @@ class OpenclProcessing(object): self.events.append(EventDescription(desc, event)) def allocate_texture(self, shape, hostbuf=None, support_1D=False): - """ - Allocate an OpenCL image ("texture"). - - :param shape: Shape of the image. Note that pyopencl and OpenCL < 1.2 - do not support 1D images, so 1D images are handled as 2D with one row - :param support_1D: force the image to be 1D if the shape has only one dim - """ - if len(shape) == 1 and not(support_1D): - shape = (1,) + shape - return pyopencl.Image( - self.ctx, - pyopencl.mem_flags.READ_ONLY | pyopencl.mem_flags.USE_HOST_PTR, - pyopencl.ImageFormat( - pyopencl.channel_order.INTENSITY, - pyopencl.channel_type.FLOAT - ), - hostbuf=numpy.zeros(shape[::-1], dtype=numpy.float32) - ) + return allocate_texture(self.ctx, shape, hostbuf=hostbuf, support_1D=support_1D) def transfer_to_texture(self, arr, tex_ref): """ @@ -336,10 +332,10 @@ class OpenclProcessing(object): if ndim == 1: # pyopencl and OpenCL < 1.2 do not support image1d_t # force 2D with one row in this case - #~ ndim = 2 + # ~ ndim = 2 shp = (1,) + shp copy_kwargs = {"origin":(0,) * ndim, "region": shp[::-1]} - if not(isinstance(arr, numpy.ndarray)): # assuming pyopencl.array.Array + if not(isinstance(arr, numpy.ndarray)): # assuming pyopencl.array.Array # D->D copy copy_args[2] = arr.data copy_kwargs["offset"] = 0 diff --git a/silx/opencl/projection.py b/silx/opencl/projection.py index da8752f..c02faf6 100644 --- a/silx/opencl/projection.py +++ b/silx/opencl/projection.py @@ -2,7 +2,7 @@ # coding: utf-8 # /*########################################################################## # -# Copyright (c) 2016 European Synchrotron Radiation Facility +# Copyright (c) 2016-2020 European Synchrotron Radiation Facility # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -115,7 +115,7 @@ class Projection(OpenclProcessing): self.offset_x = -np.float32((self.shape[1] - 1) / 2. - self.axis_pos) # TODO: custom self.offset_y = -np.float32((self.shape[0] - 1) / 2. - self.axis_pos) # TODO: custom # Reset axis_pos once offset are computed - self.axis_pos0 = np.float((self.shape[1] - 1) / 2.) + self.axis_pos0 = np.float64((self.shape[1] - 1) / 2.) # Workgroup, ndrange and shared size self.dimgrid_x = _idivup(self.dwidth, 16) @@ -129,9 +129,7 @@ class Projection(OpenclProcessing): int(self.dimgrid_y) * self.wg[1] # int(): pyopencl <= 2015.1 ) - self.is_cpu = False - if self.device.type == "CPU": - self.is_cpu = True + self._use_textures = self.check_textures_availability() # Allocate memory self.buffers = [ @@ -150,14 +148,14 @@ class Projection(OpenclProcessing): ) self._tmp_extended_img = np.zeros((self.shape[0] + 2, self.shape[1] + 2), dtype=np.float32) - if self.is_cpu: + if not(self._use_textures): self.allocate_slice() else: self.allocate_textures() self.allocate_buffers() self._ex_sino = np.zeros((self._dimrecy, self._dimrecx), dtype=np.float32) - if self.is_cpu: + if not(self._use_textures): self.cl_mem["d_slice"].fill(0.) # enqueue_fill_buffer has issues if opencl 1.2 is not present # ~ pyopencl.enqueue_fill_buffer( @@ -182,7 +180,14 @@ class Projection(OpenclProcessing): # Shorthands self._d_sino = self.cl_mem["_d_sino"] - OpenclProcessing.compile_kernels(self, self.kernel_files) + compile_options = None + if not(self._use_textures): + compile_options = "-DDONT_USE_TEXTURES" + OpenclProcessing.compile_kernels( + self, + self.kernel_files, + compile_options=compile_options + ) # check that workgroup can actually be (16, 16) self.compiletime_workgroup_size = self.kernels.max_workgroup_size("forward_kernel_cpu") @@ -194,7 +199,7 @@ class Projection(OpenclProcessing): pyopencl.enqueue_copy(self.queue, self.cl_mem["d_angles"], angles2) def allocate_slice(self): - ary = parray.zeros(self.queue, (self.shape[1] + 2, self.shape[1] + 2), np.float32) + ary = parray.empty(self.queue, (self.shape[1] + 2, self.shape[1] + 2), np.float32) ary.fill(0) self.add_to_cl_mem({"d_slice": ary}) @@ -212,7 +217,7 @@ class Projection(OpenclProcessing): image2 = image if not(image.flags["C_CONTIGUOUS"] and image.dtype == np.float32): image2 = np.ascontiguousarray(image) - if self.is_cpu: + if not(self._use_textures): # TODO: create NoneEvent return self.transfer_to_slice(image2) # ~ return pyopencl.enqueue_copy( @@ -232,7 +237,7 @@ class Projection(OpenclProcessing): ) def transfer_device_to_texture(self, d_image): - if self.is_cpu: + if not(self._use_textures): # TODO this copy should not be necessary return self.cpy2d_to_slice(d_image) else: @@ -355,14 +360,14 @@ class Projection(OpenclProcessing): assert image.ndim == 2, "Treat only 2D images" assert image.shape[0] == self.shape[0], "image shape is OK" assert image.shape[1] == self.shape[1], "image shape is OK" - if not(self.is_cpu): + if self._use_textures: self.transfer_to_texture(image) slice_ref = self.d_image_tex else: self.transfer_to_slice(image) slice_ref = self.cl_mem["d_slice"].data else: - if self.is_cpu: + if not(self._use_textures): slice_ref = self.cl_mem["d_slice"].data else: slice_ref = self.d_image_tex @@ -388,7 +393,7 @@ class Projection(OpenclProcessing): ) # Call the kernel - if self.is_cpu: + if not(self._use_textures): event_pj = self.kernels.forward_kernel_cpu( self.queue, self.ndrange, diff --git a/silx/opencl/test/test_addition.py b/silx/opencl/test/test_addition.py index 49cc0b4..19dfdf0 100644 --- a/silx/opencl/test/test_addition.py +++ b/silx/opencl/test/test_addition.py @@ -29,19 +29,17 @@ Simple test of an addition """ -from __future__ import division, print_function - __authors__ = ["Henri Payno, Jérôme Kieffer"] __contact__ = "jerome.kieffer@esrf.eu" __license__ = "MIT" __copyright__ = "2013 European Synchrotron Radiation Facility, Grenoble, France" -__date__ = "01/08/2019" +__date__ = "30/11/2020" import logging import numpy import unittest -from ..common import ocl, _measure_workgroup_size +from ..common import ocl, _measure_workgroup_size, query_kernel_info if ocl: import pyopencl import pyopencl.array @@ -116,7 +114,7 @@ class TestAddition(unittest.TestCase): @unittest.skipUnless(ocl, "pyopencl is missing") def test_measurement(self): """ - tests that all devices are working properly ... + tests that all devices are working properly ... lengthy and error prone """ for platform in ocl.platforms: for did, device in enumerate(platform.devices): @@ -124,11 +122,31 @@ class TestAddition(unittest.TestCase): self.assertEqual(meas, device.max_work_group_size, "Workgroup size for %s/%s: %s == %s" % (platform, device, meas, device.max_work_group_size)) + @unittest.skipUnless(ocl, "pyopencl is missing") + def test_query(self): + """ + tests that all devices are working properly ... lengthy and error prone + """ + for what in ("COMPILE_WORK_GROUP_SIZE", + "LOCAL_MEM_SIZE", + "PREFERRED_WORK_GROUP_SIZE_MULTIPLE", + "PRIVATE_MEM_SIZE", + "WORK_GROUP_SIZE"): + logger.info("%s: %s", what, query_kernel_info(program=self.program, kernel="addition", what=what)) + + # Not all ICD work properly .... + #self.assertEqual(3, len(query_kernel_info(program=self.program, kernel="addition", what="COMPILE_WORK_GROUP_SIZE")), "3D kernel") + + min_wg = query_kernel_info(program=self.program, kernel="addition", what="PREFERRED_WORK_GROUP_SIZE_MULTIPLE") + max_wg = query_kernel_info(program=self.program, kernel="addition", what="WORK_GROUP_SIZE") + self.assertEqual(max_wg % min_wg, 0, msg="max_wg is a multiple of min_wg") + def suite(): testSuite = unittest.TestSuite() testSuite.addTest(TestAddition("test_add")) # testSuite.addTest(TestAddition("test_measurement")) + testSuite.addTest(TestAddition("test_query")) return testSuite diff --git a/silx/opencl/test/test_backprojection.py b/silx/opencl/test/test_backprojection.py index b2f2070..9dfdd3a 100644 --- a/silx/opencl/test/test_backprojection.py +++ b/silx/opencl/test/test_backprojection.py @@ -96,8 +96,9 @@ class TestFBP(unittest.TestCase): # Therefore, we cannot expect results to be the "same" (up to float32 # numerical error) self.tol = 5e-2 - if self.fbp.is_cpu: + if not(self.fbp._use_textures) or self.fbp.device.type == "CPU": # Precision is less when using CPU + # (either CPU textures or "manual" linear interpolation) self.tol *= 2 def tearDown(self): diff --git a/silx/opencl/test/test_convolution.py b/silx/opencl/test/test_convolution.py index 27cb8a9..7bceb0d 100644 --- a/silx/opencl/test/test_convolution.py +++ b/silx/opencl/test/test_convolution.py @@ -41,15 +41,18 @@ from itertools import product import numpy as np from silx.utils.testutils import parameterize from silx.image.utils import gaussian_kernel + try: from scipy.ndimage import convolve, convolve1d from scipy.misc import ascent + scipy_convolve = convolve scipy_convolve1d = convolve1d except ImportError: scipy_convolve = None import unittest -from ..common import ocl +from ..common import ocl, check_textures_availability + if ocl: import pyopencl as cl import pyopencl.array as parray @@ -59,7 +62,6 @@ logger = logging.getLogger(__name__) @unittest.skipUnless(ocl and scipy_convolve, "PyOpenCl/scipy is missing") class TestConvolution(unittest.TestCase): - @classmethod def setUpClass(cls): super(TestConvolution, cls).setUpClass() @@ -67,7 +69,7 @@ class TestConvolution(unittest.TestCase): cls.data1d = cls.image[0] cls.data2d = cls.image cls.data3d = np.tile(cls.image[224:-224, 224:-224], (62, 1, 1)) - cls.kernel1d = gaussian_kernel(1.) + cls.kernel1d = gaussian_kernel(1.0) cls.kernel2d = np.outer(cls.kernel1d, cls.kernel1d) cls.kernel3d = np.multiply.outer(cls.kernel2d, cls.kernel1d) cls.ctx = ocl.create_context() @@ -97,7 +99,7 @@ class TestConvolution(unittest.TestCase): ) return errmsg - def __init__(self, methodName='runTest', param=None): + def __init__(self, methodName="runTest", param=None): unittest.TestCase.__init__(self, methodName) self.param = param self.mode = param["boundary_handling"] @@ -107,32 +109,27 @@ class TestConvolution(unittest.TestCase): use_textures=%s, input_device=%s, output_device=%s """ % ( - self.mode, param["use_textures"], - param["input_on_device"], param["output_on_device"] + self.mode, + param["use_textures"], + param["input_on_device"], + param["output_on_device"], ) ) def instantiate_convol(self, shape, kernel, axes=None): - def is_fermi_device(dev): - try: - res = (dev.compute_capability_major_nv < 3) - except cl.LogicError: - res = False - except AttributeError: - res = False - return res - if (self.mode == "constant") and ( - not(self.param["use_textures"]) - or (self.ctx.devices[0].type == cl._cl.device_type.CPU) - or (is_fermi_device(self.ctx.devices[0])) + if self.mode == "constant": + if not (self.param["use_textures"]) or ( + self.param["use_textures"] + and not (check_textures_availability(self.ctx)) ): self.skipTest("mode=constant not implemented without textures") C = Convolution( - shape, kernel, + shape, + kernel, mode=self.mode, ctx=self.ctx, axes=axes, - extra_options={"dont_use_textures": not(self.param["use_textures"])} + extra_options={"dont_use_textures": not (self.param["use_textures"])}, ) return C @@ -142,13 +139,9 @@ class TestConvolution(unittest.TestCase): "test_separable_2D": (2, 1), "test_separable_3D": (3, 1), "test_nonseparable_2D": (2, 2), - "test_nonseparable_3D": (3, 3), - } - dim_data = { - 1: self.data1d, - 2: self.data2d, - 3: self.data3d + "test_nonseparable_3D": (3, 3), } + dim_data = {1: self.data1d, 2: self.data2d, 3: self.data3d} dim_kernel = { 1: self.kernel1d, 2: self.kernel2d, @@ -159,24 +152,26 @@ class TestConvolution(unittest.TestCase): def get_reference_function(self, test_name): ref_func = { - "test_1D": - lambda x, y : scipy_convolve1d(x, y, mode=self.mode), - "test_separable_2D": - lambda x, y : scipy_convolve1d( - scipy_convolve1d(x, y, mode=self.mode, axis=1), - y, mode=self.mode, axis=0 - ), - "test_separable_3D": - lambda x, y: scipy_convolve1d( - scipy_convolve1d( - scipy_convolve1d(x, y, mode=self.mode, axis=2), - y, mode=self.mode, axis=1), - y, mode=self.mode, axis=0 + "test_1D": lambda x, y: scipy_convolve1d(x, y, mode=self.mode), + "test_separable_2D": lambda x, y: scipy_convolve1d( + scipy_convolve1d(x, y, mode=self.mode, axis=1), + y, + mode=self.mode, + axis=0, + ), + "test_separable_3D": lambda x, y: scipy_convolve1d( + scipy_convolve1d( + scipy_convolve1d(x, y, mode=self.mode, axis=2), + y, + mode=self.mode, + axis=1, ), - "test_nonseparable_2D": - lambda x, y: scipy_convolve(x, y, mode=self.mode), - "test_nonseparable_3D": - lambda x, y : scipy_convolve(x, y, mode=self.mode), + y, + mode=self.mode, + axis=0, + ), + "test_nonseparable_2D": lambda x, y: scipy_convolve(x, y, mode=self.mode), + "test_nonseparable_3D": lambda x, y: scipy_convolve(x, y, mode=self.mode), } return ref_func[test_name] @@ -226,8 +221,8 @@ class TestConvolution(unittest.TestCase): data = self.data3d kernel = self.kernel2d conv = self.instantiate_convol(data.shape, kernel, axes=(0,)) - res = conv(data) # 3D - ref = scipy_convolve(data[0], kernel, mode=self.mode) # 2D + res = conv(data) # 3D + ref = scipy_convolve(data[0], kernel, mode=self.mode) # 2D std = np.std(res, axis=0) std_max = np.max(np.abs(std)) @@ -244,12 +239,9 @@ def test_convolution(): output_on_device_ = [True, False] testSuite = unittest.TestSuite() - param_vals = list(product( - boundary_handling_, - use_textures_, - input_on_device_, - output_on_device_ - )) + param_vals = list( + product(boundary_handling_, use_textures_, input_on_device_, output_on_device_) + ) for boundary_handling, use_textures, input_dev, output_dev in param_vals: testcase = parameterize( TestConvolution, @@ -258,17 +250,16 @@ def test_convolution(): "input_on_device": input_dev, "output_on_device": output_dev, "use_textures": use_textures, - } + }, ) testSuite.addTest(testcase) return testSuite - def suite(): testSuite = test_convolution() return testSuite -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(defaultTest="suite") diff --git a/silx/resources/gui/icons/add.png b/silx/resources/gui/icons/add.png new file mode 100644 index 0000000..80c6400 Binary files /dev/null and b/silx/resources/gui/icons/add.png differ diff --git a/silx/resources/gui/icons/add.svg b/silx/resources/gui/icons/add.svg new file mode 100644 index 0000000..19c1a6d --- /dev/null +++ b/silx/resources/gui/icons/add.svg @@ -0,0 +1,2 @@ + +image/svg+xml diff --git a/silx/resources/gui/icons/backend-opengl.png b/silx/resources/gui/icons/backend-opengl.png new file mode 100644 index 0000000..ff81f64 Binary files /dev/null and b/silx/resources/gui/icons/backend-opengl.png differ diff --git a/silx/resources/gui/icons/backend-opengl.svg b/silx/resources/gui/icons/backend-opengl.svg new file mode 100644 index 0000000..41d79b8 --- /dev/null +++ b/silx/resources/gui/icons/backend-opengl.svg @@ -0,0 +1,18 @@ + + + + + + image/svg+xml + + + + + + + + + + + + diff --git a/silx/resources/gui/icons/rm.png b/silx/resources/gui/icons/rm.png new file mode 100644 index 0000000..ecff08b Binary files /dev/null and b/silx/resources/gui/icons/rm.png differ diff --git a/silx/resources/gui/icons/rm.svg b/silx/resources/gui/icons/rm.svg new file mode 100644 index 0000000..7cc515e --- /dev/null +++ b/silx/resources/gui/icons/rm.svg @@ -0,0 +1,2 @@ + +image/svg+xml diff --git a/silx/resources/opencl/backproj.cl b/silx/resources/opencl/backproj.cl index 6fadc2c..da15131 100644 --- a/silx/resources/opencl/backproj.cl +++ b/silx/resources/opencl/backproj.cl @@ -35,7 +35,7 @@ /************************ GPU VERSION (with textures) **************************/ /*******************************************************************************/ - +#ifndef DONT_USE_TEXTURES kernel void backproj_kernel( int num_proj, int num_bins, @@ -55,11 +55,6 @@ kernel void backproj_kernel( const int tidy = get_local_id(1); //threadIdx.y; const int bidy = get_group_id(1); //blockIdx.y; - //~ local float shared[768]; - //~ float * sh_sin = shared; - //~ float * sh_cos = shared+256; - //~ float * sh_axis = sh_cos+256; - local float sh_cos[256]; local float sh_sin[256]; local float sh_axis[256]; @@ -107,7 +102,7 @@ kernel void backproj_kernel( d_SLICE[ 32*get_num_groups(0)*(bidy*32+tidy*2+0) + bidx*32 + tidx*2 + 1] = res2; d_SLICE[ 32*get_num_groups(0)*(bidy*32+tidy*2+1) + bidx*32 + tidx*2 + 1] = res3; } - +#endif @@ -134,7 +129,7 @@ static float linear_interpolation(float2 vals, { if (xm == xp) return vals.s0; - else + else return (vals.s0 * (xp - x)) + (vals.s1 * (x - xm)); } @@ -197,280 +192,36 @@ kernel void backproj_cpu_kernel( h1 = (acorr05 + (bx00+0)*pcos - (by00+1)*psin); h2 = (acorr05 + (bx00+1)*pcos - (by00+0)*psin); h3 = (acorr05 + (bx00+1)*pcos - (by00+1)*psin); - - - float x; - int ym, xm, xp; - ym = proj; - float2 vals; - - if(h0>=0 && h0=0 && h1=0 && h2=0 && h3= Nx) adj_coords.s1 = Nx - 1; - if (adj_coords.s2 < 0) adj_coords.s2 = 0; - if (adj_coords.s3 >= Ny) adj_coords.s3 = Ny -1; - if (adj_coords.s0 >= Nx) adj_coords.s0 = Nx - 1; - if (adj_coords.s2 >= Ny) adj_coords.s2 = Ny -1; - // Interp - val = adj_vals.s1*(adj_coords.s1-x)*(y-adj_coords.s2) - + adj_vals.s2 *(x-adj_coords.s0)*(y-adj_coords.s2) - + adj_vals.s0 *(adj_coords.s1-x)*(adj_coords.s3-y) - + adj_vals.s3 *(x-adj_coords.s0)*(adj_coords.s3-y); - - } - return val; -} -*/ - - -/* -__kernel void backproj_cpu_kernel_good( - int num_proj, - int num_bins, - float axis_position, - __global float *d_SLICE, - __global float* d_sino, - float gpu_offset_x, - float gpu_offset_y, - __global float * d_cos_s, // precalculated cos(theta[i]) - __global float * d_sin_s, // precalculated sin(theta[i]) - __global float * d_axis_s, // array of axis positions (n_projs) - __local float* shared2) // 768B of local mem -{ - const int tidx = get_local_id(0); //threadIdx.x; - const int bidx = get_group_id(0); //blockIdx.x; - const int tidy = get_local_id(1); //threadIdx.y; - const int bidy = get_group_id(1); //blockIdx.y; - - //~ __local float shared[768]; - //~ float * sh_sin = shared; - //~ float * sh_cos = shared+256; - //~ float * sh_axis = sh_cos+256; - __local float sh_cos[256]; - __local float sh_sin[256]; - __local float sh_axis[256]; - - float pcos, psin; - float h0, h1, h2, h3; - const float apos_off_x= gpu_offset_x - axis_position ; - const float apos_off_y= gpu_offset_y - axis_position ; - float acorr05; - float res0 = 0, res1 = 0, res2 = 0, res3 = 0; - const float bx00 = (32 * bidx + 2 * tidx + 0 + apos_off_x ) ; - const float by00 = (32 * bidy + 2 * tidy + 0 + apos_off_y ) ; + float x; + int ym, xm, xp; + ym = proj; + float2 vals; - int read=0; - for(int proj=0; proj=read) { - barrier(CLK_LOCAL_MEM_FENCE); - int ip = tidy*16+tidx; - if( read+ip < num_proj) { - sh_cos [ip] = d_cos_s[read+ip] ; - sh_sin [ip] = d_sin_s[read+ip] ; - sh_axis[ip] = d_axis_s[read+ip] ; - } - read=read+256; // 256=16*16 block size - barrier(CLK_LOCAL_MEM_FENCE); + if(h0>=0 && h0=0 && h0=0 && h1=0 && h2=0 && h3